TorchScript model doesn't work with autocast

Hi, I want to use autocast with a scripted model, but I get the following error:

import torch
import torch.nn as nn
from torch.amp import autocast

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)

    def forward(self, x):
        with autocast(device_type="cuda", enabled=True):
            return self.conv1(x)

device = torch.device("cuda")

# Create an instance of the network
net = SimpleCNN()
net.to(device)

# Create a sample input tensor
input_tensor = torch.randn(1, 3, 28, 28)
input_tensor = input_tensor.to(device)

# Pass the input through the network
output = net(input_tensor)
print(output.shape)

# Pass the input through the scripted network
script_net = torch.jit.script(net)
output2 = script_net(input_tensor)
print(output2.shape)

The eager model runs and prints the expected shape, but the scripted model raises this error:

torch.Size([1, 6, 24, 24])
Traceback (most recent call last):
  File "D:\workspace_tf\Einstein_reg_8\src\neosoft\misc\SimpleCNN.py", line 39, in <module>
    output2 = script_net(input_tensor)
  File "C:\Python39\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Python39\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "D:\workspace_tf\Einstein_reg_8\src\neosoft\misc\SimpleCNN.py", line 21, in forward
            # x = torch.relu(self.fc1(x))
            # return x
            return self.conv1(x)
                   ~~~~~~~~~~ <--- HERE
  File "C:\Python39\lib\site-packages\torch\nn\modules\conv.py", line 554, in forward
    def forward(self, input: Tensor) -> Tensor:
        return self._conv_forward(input, self.weight, self.bias)
               ~~~~~~~~~~~~~~~~~~ <--- HERE
  File "C:\Python39\lib\site-packages\torch\nn\modules\conv.py", line 549, in _conv_forward
                self.groups,
            )
        return F.conv2d(
               ~~~~~~~~ <--- HERE
            input, weight, bias, self.stride, self.padding, self.dilation, self.groups
        )
RuntimeError: Input type (struct c10::Half) and bias type (float) should be the same

Try wrapping autocast around the forward pass (i.e. around the call to the model) as shown in the docs, and your model should work. Also note that TorchScript is in maintenance mode; we recommend using torch.compile instead.
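A minimal sketch of that suggestion, assuming the same SimpleCNN with the autocast block removed from forward, so that the caller enters autocast around the scripted model call. As an added assumption, so the snippet also runs on a machine without a GPU, it falls back to CPU autocast with bfloat16.

```python
import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)

    def forward(self, x):
        # No autocast inside forward: the caller decides the precision.
        return self.conv1(x)

# Assumption: use CUDA/float16 if available, otherwise CPU/bfloat16.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16

net = SimpleCNN().to(device_type)
script_net = torch.jit.script(net)
input_tensor = torch.randn(1, 3, 28, 28, device=device_type)

# Autocast wraps the call to the scripted model, not the code inside it.
with torch.autocast(device_type=device_type, dtype=amp_dtype):
    output = script_net(input_tensor)

print(output.shape, output.dtype)
```

Keeping autocast out of forward lets the caller choose the precision; the failing version above had the autocast block inside the scripted forward.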

Thank you for the reply. I want to apply autocast to only part of the model, because it uses PCA, which doesn't support half precision. I will try torch.compile.
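For partial use, one common pattern is to keep autocast enabled around the whole call and disable it locally for the precision-sensitive step, casting that step's inputs back to float32. In this sketch `torch.pca_lowrank` stands in for the model's PCA step (an assumption, since the actual PCA code isn't shown), and `PartialAMPNet` and `n_components` are hypothetical names.

```python
import torch
import torch.nn as nn

class PartialAMPNet(nn.Module):  # hypothetical model for illustration
    def __init__(self, n_components: int = 4):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.n_components = n_components  # hypothetical number of PCA components

    def forward(self, x):
        # Runs in reduced precision when the caller has enabled autocast.
        feats = self.conv1(x)
        # Disable autocast locally for the PCA step and cast back to float32.
        with torch.autocast(device_type=x.device.type, enabled=False):
            feats = feats.flatten(1).float()
            # Stand-in PCA: torch.pca_lowrank returns (U, S, V); V's columns are principal directions.
            _, _, V = torch.pca_lowrank(feats, q=self.n_components)
            centered = feats - feats.mean(dim=0, keepdim=True)
            return centered @ V  # project onto the top principal components

# Assumption: use CUDA/float16 if available, otherwise CPU/bfloat16.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16

net = PartialAMPNet().to(device_type)
x = torch.randn(8, 3, 28, 28, device=device_type)

with torch.autocast(device_type=device_type, dtype=amp_dtype):
    out = net(x)

print(out.shape, out.dtype)
```

Because autocast here is ordinary eager-mode code with no TorchScript involved, the module can be passed to `torch.compile(net)` as the reply recommends.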