PyTorch model FP32 to FP16 using half() - LSTM block is not cast

You are right that model.half() casts all parameters and buffers to float16, and you also correctly note that h and c are inputs, not parameters. If you do not pass them explicitly to the model, the LSTM's forward method initializes them as zeros in the dtype of its parameters, so they will match automatically:

model.half()          # casts parameters and buffers to float16
input = input.half()  # inputs must be cast manually

out, (h, c) = model(input)  # h and c are initialized in float16 internally
print(out.dtype, h.dtype, c.dtype)
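As a minimal self-contained sketch (the layer sizes below are arbitrary, not from your model): after .half(), every parameter of the LSTM is float16, and if you do want to build h0/c0 yourself, the robust pattern is to read the dtype (and device) off the model's parameters rather than hard-coding it:

```python
import torch
import torch.nn as nn

# Arbitrary example sizes, just for illustration.
model = nn.LSTM(input_size=8, hidden_size=16, num_layers=1, batch_first=True)
model.half()

# All parameters are now float16.
print(model.weight_ih_l0.dtype)  # torch.float16

# If you create the initial states manually, match the parameter
# dtype/device explicitly instead of assuming float32:
p = next(model.parameters())
h0 = torch.zeros(1, 4, 16, dtype=p.dtype, device=p.device)
c0 = torch.zeros(1, 4, 16, dtype=p.dtype, device=p.device)
print(h0.dtype, c0.dtype)  # torch.float16 torch.float16
```

This way the same code keeps working if you later switch the model back to float32 or move it to a GPU.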