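You are right that `model.half()` will transform all floating-point parameters and buffers to `float16`, but you also correctly mentioned that `h` and `c` are inputs. If you do not pass them explicitly to the model, it’ll be smart enough to initialize them in the right `dtype` for you in the `forward` method. It creates them as zeros with the same `dtype` and device as the input, which is why the input also needs to be cast with `.half()`: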
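```python
import torch
import torch.nn as nn

model = nn.LSTM(input_size=10, hidden_size=20).cuda().half()  # example sizes; float16 LSTMs are usually run on the GPU
input = torch.randn(5, 3, 10, device="cuda").half()  # (seq_len, batch, input_size)
out, (h, c) = model(input)  # h and c are created internally with input's dtype
print(out.dtype, h.dtype, c.dtype)  # torch.float16 torch.float16 torch.float16
```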
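To check the same mechanism without a GPU, here is a minimal sketch that uses `float64` in place of `float16`. That substitution is an assumption made only so it runs on CPU, and the layer sizes are arbitrary. It shows that the implicitly created states follow the input's `dtype`, and that if you do pass `(h0, c0)` yourself, you have to create them with the matching `dtype` and device.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# float64 stands in for float16 so this runs on CPU; the dtype logic is the same.
model = nn.LSTM(input_size=4, hidden_size=3, num_layers=2).double()
x = torch.randn(6, 2, 4, dtype=torch.float64)  # (seq_len, batch, input_size)

# 1) No (h0, c0) passed: the LSTM creates zero initial states itself,
#    using the same dtype and device as the input.
out_implicit, (h, c) = model(x)
print(out_implicit.dtype, h.dtype, c.dtype)  # torch.float64 torch.float64 torch.float64

# 2) Explicit initial states: you are responsible for matching dtype and device.
batch = x.size(1)
h0 = torch.zeros(model.num_layers, batch, model.hidden_size,
                 dtype=x.dtype, device=x.device)
c0 = torch.zeros_like(h0)
out_explicit, _ = model(x, (h0, c0))

# Explicit zero states give the same result as letting the module create them.
print(torch.allclose(out_implicit, out_explicit))  # True
```

In practice, deriving `dtype` and `device` from the input, as above, keeps explicitly passed states consistent with whatever precision the model and input were cast to.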