PyTorch model FP32 to FP16 using half() - LSTM block is not cast

I am trying to convert some pre-trained models to half precision for deployment.

I tried

model = model.half()

and it cast the basic FP32 CNN model to FP16 without problems,
but not the LSTM model.

I got this error when an FP16 input was passed to the torch.nn.LSTM layer after I called model.half():

  File "/~env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/~env/lib/python3.9/site-packages/torch/nn/modules/rnn.py", line 769, in forward
    result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
RuntimeError: Input and hidden tensors are not the same dtype, found input tensor with Half and hidden tensor with Float

According to the docs, half() casts all floating point parameters and buffers to the half datatype and modifies the module in-place, right?

So why isn't the LSTM layer cast, and how can I cast it to the half datatype?

It works for me:

import torch
import torch.nn as nn

# float32
model = nn.LSTM(10, 20, 2).cuda()
input = torch.randn(5, 3, 10).cuda()
h0 = torch.randn(2, 3, 20).cuda()
c0 = torch.randn(2, 3, 20).cuda()

out, (h, c) = model(input, (h0, c0))
print(out.dtype, h.dtype, c.dtype)
# torch.float32 torch.float32 torch.float32

# float16
model.half()
input = input.half()
h0 = h0.half()
c0 = c0.half()

out, (h, c) = model(input, (h0, c0))
print(out.dtype, h.dtype, c.dtype)
# torch.float16 torch.float16 torch.float16

and based on the error message I would assume you are not casting all inputs to float16.


Thank you for the answer and the example!

I expected model.half() to convert all the parameters and modules in the model to FP16.
h0 and c0 are defined inside that model, but they are also inputs to the LSTM layer.
I didn't convert h0 and c0 to FP16 manually, and that's why I got the error above.

You are right that model.half() will transform all parameters and buffers to float16, but you also correctly mentioned that h and c are inputs. If you do not pass them explicitly to the model, it’ll be smart enough to initialize them in the right dtype for you in the forward method:

model.half()
input = input.half()

out, (h, c) = model(input)
print(out.dtype, h.dtype, c.dtype)
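
If the model does need to build h0 and c0 itself, they are plain tensors rather than parameters or buffers, so model.half() will not touch them. A minimal sketch of one way around that, assuming a hypothetical wrapper module named Net, is to create them with the input's dtype and device. The float64 fallback on CPU is only an assumption made so the sketch runs without a GPU, since half-precision LSTM kernels may not be available in every CPU build:

```python
import torch
import torch.nn as nn


class Net(nn.Module):  # hypothetical wrapper, not the original poster's model
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(10, 20, 2)

    def forward(self, x):
        # Plain tensors created here are not parameters or buffers, so
        # model.half() never sees them. Taking dtype and device from the
        # input keeps them consistent with whatever precision is in use.
        shape = (self.lstm.num_layers, x.size(1), self.lstm.hidden_size)
        h0 = torch.zeros(shape, dtype=x.dtype, device=x.device)
        c0 = torch.zeros(shape, dtype=x.dtype, device=x.device)
        out, _ = self.lstm(x, (h0, c0))
        return out


device = "cuda" if torch.cuda.is_available() else "cpu"
# Assumption: use float16 on GPU, float64 as a stand-in on CPU.
dtype = torch.float16 if device == "cuda" else torch.float64

model = Net().to(device=device, dtype=dtype)
x = torch.randn(5, 3, 10, device=device, dtype=dtype)
out = model(x)
print(out.dtype == x.dtype)
```

With this pattern no dtype check is needed inside forward, because the hidden state simply follows the input.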

Thank you for the further information!

I had updated the model to use an if-condition to convert the dtypes of h0 and c0 before I saw your second comment.

But now I have updated the model with your second suggestion, and yes, it is smart enough to initialize them in the right dtype.

I was planning to trace the model, and your answer helped me avoid adding conditional control flow.
Thank you :)
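
For reference, a rough sketch of what tracing without the conditional might look like. The module name Encoder is hypothetical, and the float64 fallback on CPU is an assumption made only so the sketch runs without a GPU. The hidden state is left out of the call so nn.LSTM creates it to match the input:

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):  # hypothetical module used only for illustration
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(10, 20, 2)

    def forward(self, x):
        # No explicit (h0, c0): nn.LSTM initializes zeros matching x,
        # so forward contains no dtype-dependent branch.
        out, _ = self.lstm(x)
        return out


device = "cuda" if torch.cuda.is_available() else "cpu"
# Assumption: use float16 on GPU, float64 as a stand-in on CPU.
dtype = torch.float16 if device == "cuda" else torch.float64

model = Encoder().to(device=device, dtype=dtype).eval()
example = torch.randn(5, 3, 10, device=device, dtype=dtype)

with torch.no_grad():
    # Cast first, then trace with an example input of the deployment dtype.
    traced = torch.jit.trace(model, example)
    traced_out = traced(example)

print(traced_out.dtype == example.dtype)
```

Casting before tracing matters here, since the trace records the operations as they ran on the example input.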
