RuntimeError: expected scalar type Half but found Float from fc layers in TorchScript

Hi, I am getting a runtime error when training a scripted (TorchScript) model that contains two fc layers with mixed precision.

Environment: PyTorch 1.9.0 + CUDA 11.4
Code to reproduce the error:

```python
import torch

# Dummy data: 6 feature maps with 152 channels, and 2 regression targets per sample
inputs = torch.zeros([6, 152, 128, 480], dtype=torch.float32).cuda()
targets = torch.ones([6, 2], dtype=torch.float32).cuda()

class Net(torch.jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            torch.nn.Flatten(1),
            torch.nn.Linear(152, 32),  # first fc layer
            torch.nn.Linear(32, 2),    # second fc layer; with only one fc layer there is no error
        )

    @torch.jit.script_method
    def forward(self, x):
        return self.net(x)

net = Net().cuda()
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), 0.001)
scaler = torch.cuda.amp.GradScaler()

for i in range(2):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        outputs = net(inputs)
        loss = loss_fn(outputs, targets)
    scaler.scale(loss).backward()  # the RuntimeError below is raised in the backward pass
    scaler.step(optimizer)
    scaler.update()
```

Error:

```
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "<string>", line 75, in <backward op>

                weight_size = weight.size()
                grad_input = torch.matmul(grad_output, weight)
                             ~~~~~~~~~~~~ <--- HERE
                grad_weight = torch.matmul(grad_output.reshape(-1, weight_size[0]).t(), input.reshape(-1, weight_size[1]))
                # Note: calling unchecked_unwrap_optional is only safe, when we
RuntimeError: expected scalar type Half but found Float
```
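The failing line is in the TorchScript backward of a linear layer: `grad_input = torch.matmul(grad_output, weight)`. Judging from the message, the incoming gradient is presumably float16 (produced under autocast) while `weight` is still the float32 parameter, and `torch.matmul` does not promote mixed dtypes. A minimal eager sketch of that mismatch, with made-up tensor shapes chosen to match `Linear(32, 2)` (the exact error text differs between PyTorch versions):

```python
import torch

# Hypothetical stand-ins for the two operands in the failing backward line:
# a float16 gradient (as autocast would produce) and a float32 weight of Linear(32, 2).
grad_output = torch.ones(4, 2, dtype=torch.float16)
weight = torch.ones(2, 32, dtype=torch.float32)

try:
    # Same call as in the trace: grad_input = torch.matmul(grad_output, weight)
    torch.matmul(grad_output, weight)
    mismatch_error = None
except RuntimeError as err:
    # matmul requires both operands to share a dtype, so it raises instead of upcasting
    mismatch_error = str(err)

print(mismatch_error)
```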

The error only occurs when the scripted module contains two fc layers; with a single fc layer there is no error. Is there any way to work around this for now? Thank you!
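If upgrading is not possible right away, one workaround you could try (a sketch, not verified on 1.9) is to keep the fc layers scripted but run them outside autocast in float32, so no float16 tensors ever reach the scripted graph and its backward has nothing to mix. `Head` and the `autocast` helper are hypothetical names; the input is shrunk so the sketch also runs on a CPU-only machine, where AMP is simply skipped:

```python
import contextlib

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # torch.cuda.amp only applies on a GPU


def autocast(enabled: bool):
    # Hypothetical helper: torch.cuda.amp.autocast on GPU, a no-op context on CPU
    if use_amp:
        return torch.cuda.amp.autocast(enabled=enabled)
    return contextlib.nullcontext()


class Head(torch.nn.Module):
    # Same layers as the repro, as a plain nn.Module scripted with torch.jit.script
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            torch.nn.Flatten(1),
            torch.nn.Linear(152, 32),
            torch.nn.Linear(32, 2),
        )

    def forward(self, x):
        return self.net(x)


head = torch.jit.script(Head()).to(device)
inputs = torch.zeros([6, 152, 8, 8], device=device)  # smaller than the original 128x480 maps
targets = torch.ones([6, 2], device=device)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(head.parameters(), 0.001)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for i in range(3):
    optimizer.zero_grad()
    with autocast(True):
        # A non-scripted backbone would run here in mixed precision; this sketch has none.
        features = inputs
        # Workaround: disable autocast for the scripted fc layers and feed them float32,
        # so both their forward and backward run entirely in float32.
        with autocast(False):
            outputs = head(features.float())
        loss = loss_fn(outputs, targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```

The trade-off is that the fc layers lose the float16 speedup, which is usually negligible for layers this small compared with a convolutional backbone.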

Scripting combined with mixed-precision training had some issues in the past, so you could update to the latest stable release or nightly, where it should work.
I just verified your code with the current 1.12.0+cu116 release and don’t see any issues.
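To re-check the repro after upgrading, a sketch of the same training loop written with `torch.jit.script` and the device-generic `torch.autocast` (available since PyTorch 1.10) could look like the following. It prints the installed versions for the report, uses a smaller input than the original, and on a CPU-only machine runs with AMP disabled, so it only exercises the failing path on a GPU:

```python
import torch

# Record the versions being tested, as in the environment line of the question
print(torch.__version__, torch.version.cuda)

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # on a CPU-only machine this just runs in float32

# Same layers as the repro, scripted with torch.jit.script instead of ScriptModule
net = torch.jit.script(
    torch.nn.Sequential(
        torch.nn.AdaptiveAvgPool2d(output_size=(1, 1)),
        torch.nn.Flatten(1),
        torch.nn.Linear(152, 32),
        torch.nn.Linear(32, 2),
    )
).to(device)

inputs = torch.zeros([6, 152, 8, 8], device=device)  # smaller spatial size than the repro
targets = torch.ones([6, 2], device=device)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), 0.001)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for i in range(3):
    optimizer.zero_grad()
    # torch.autocast(device_type="cuda") is the device-generic form of torch.cuda.amp.autocast()
    with torch.autocast(device_type=device, enabled=use_amp):
        outputs = net(inputs)
        loss = loss_fn(outputs, targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

print("finished", i + 1, "iterations, loss =", loss.item())
```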

@ptrblck Thank you for your timely reply!