Hi, I want to use autocast with a scripted (TorchScript) model, and I got the following error.
import torch
import torch.nn as nn
from torch.amp import autocast
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)

    def forward(self, x):
        with autocast(device_type="cuda", enabled=True):
            return self.conv1(x)
device = torch.device('cuda')
# Create an instance of the network
net = SimpleCNN()
net.to(device)
# Create a sample input tensor
input_tensor = torch.randn(1, 3, 28, 28)
input_tensor = input_tensor.to(device)
# Pass the input through the network
output = net(input_tensor)
print(output.shape)
# Pass the input through the scripted network
script_net = torch.jit.script(net)
output2 = script_net(input_tensor)
print(output2.shape)
The eager call succeeds and prints the output shape, but the scripted call fails with the error below:
torch.Size([1, 6, 24, 24])
Traceback (most recent call last):
File "D:\workspace_tf\Einstein_reg_8\src\neosoft\misc\SimpleCNN.py", line 39, in <module>
output2 = script_net(input_tensor)
File "C:\Python39\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\Python39\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
File "D:\workspace_tf\Einstein_reg_8\src\neosoft\misc\SimpleCNN.py", line 21, in forward
# x = torch.relu(self.fc1(x))
# return x
return self.conv1(x)
~~~~~~~~~~ <--- HERE
File "C:\Python39\lib\site-packages\torch\nn\modules\conv.py", line 554, in forward
def forward(self, input: Tensor) -> Tensor:
return self._conv_forward(input, self.weight, self.bias)
~~~~~~~~~~~~~~~~~~ <--- HERE
File "C:\Python39\lib\site-packages\torch\nn\modules\conv.py", line 549, in _conv_forward
self.groups,
)
return F.conv2d(
~~~~~~~~ <--- HERE
input, weight, bias, self.stride, self.padding, self.dilation, self.groups
)
RuntimeError: Input type (struct c10::Half) and bias type (float) should be the same