Can you do mixed precision with just the forward pass and not the loss?

PyTorch autocast mixed precision is useful, but all of the examples wrap both the forward pass and the loss computation in the autocast context. Does it also work if only the forward pass is inside the context and the loss is computed outside it? For code-elegance reasons, I need to exclude part of the loss term from autocast.

I think it should work fine. Here is a simple toy example that computes the same model and loss both ways:

import torch

batch_size = 128

print("running both in autocast")
loss_fn = torch.nn.CrossEntropyLoss()
torch.manual_seed(0)
inp = torch.randn(batch_size, 1, 28, 28, device='cuda')
target = torch.randint(0, 4, (batch_size,), device='cuda')
c = torch.nn.Conv2d(1, 4, 3, device='cuda')
with torch.autocast('cuda'):
  o = torch.mean(c(inp), (2, 3))
  loss = loss_fn(o, target)
print(loss)
loss.backward()
print(c.weight.grad)

print("running loss outside of autocast")
loss_fn = torch.nn.CrossEntropyLoss()
torch.manual_seed(0)
inp = torch.randn(batch_size, 1, 28, 28, device='cuda')
target = torch.randint(0, 4, (batch_size,), device='cuda')
c = torch.nn.Conv2d(1, 4, 3, device='cuda')
with torch.autocast('cuda'):
  o = torch.mean(c(inp), (2, 3))
loss = loss_fn(o.float(), target)  # o is float16 from autocast; upcast before the loss
print(loss)
loss.backward()
print(c.weight.grad)
Output:

running both in autocast
tensor(1.4228, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor([[[[ 5.9128e-04,  1.0681e-03,  8.3590e-04],
          [ 4.1270e-04,  1.0691e-03,  8.1015e-04],
          [-3.1173e-05,  5.8031e-04,  3.9482e-04]]],


        [[[-2.8467e-04, -5.3453e-04,  1.1891e-04],
          [ 1.7011e-04, -1.2958e-04,  3.4952e-04],
          [ 6.0225e-04,  2.3711e-04,  5.9605e-04]]],


        [[[-2.9230e-04, -3.6621e-04, -5.6601e-04],
          [-3.0875e-04, -4.0889e-04, -5.2071e-04],
          [-5.1451e-04, -4.6921e-04, -6.7139e-04]]],


        [[[-9.8944e-06, -1.6201e-04, -3.8552e-04],
          [-2.7370e-04, -5.3024e-04, -6.4039e-04],
          [-5.6267e-05, -3.4690e-04, -3.2020e-04]]]], device='cuda:0')
running loss outside of autocast
tensor(1.4228, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor([[[[ 5.9128e-04,  1.0662e-03,  8.3447e-04],
          [ 4.1318e-04,  1.0681e-03,  8.0919e-04],
          [-3.0875e-05,  5.7936e-04,  3.9387e-04]]],


        [[[-2.8634e-04, -5.3549e-04,  1.1861e-04],
          [ 1.6904e-04, -1.2982e-04,  3.4976e-04],
          [ 6.0081e-04,  2.3639e-04,  5.9605e-04]]],


        [[[-2.9230e-04, -3.6621e-04, -5.6601e-04],
          [-3.0875e-04, -4.0889e-04, -5.2071e-04],
          [-5.1451e-04, -4.6921e-04, -6.7139e-04]]],


        [[[-7.3910e-06, -1.5986e-04, -3.8457e-04],
          [-2.6965e-04, -5.2691e-04, -6.3753e-04],
          [-5.2571e-05, -3.4356e-04, -3.1829e-04]]]], device='cuda:0')

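In this run the loss is identical and the gradients differ by at most a few times 1e-6. However, no automatic casting happens outside an autocast context, so dtypes can end up different from what you would get inside it. I would double-check, e.g., that the input and output dtypes of your loss function are what you expect.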
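To make that dtype check concrete, here is a minimal sketch that runs the forward pass under autocast, computes the main loss outside it, and excludes one extra loss term inside the context using a nested `torch.autocast(..., enabled=False)` region. Assumptions: a small `Linear` model instead of the conv above, a hypothetical `extra` regularization term, and a CPU fallback that uses bfloat16 (CPU autocast's lower-precision dtype) when no GPU is available.

```python
import torch
import torch.nn.functional as F

# Assumption: use CUDA/float16 if available, otherwise CPU/bfloat16 autocast.
device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

torch.manual_seed(0)
model = torch.nn.Linear(16, 4).to(device)
inp = torch.randn(8, 16, device=device)
target = torch.randint(0, 4, (8,), device=device)

with torch.autocast(device_type=device, dtype=amp_dtype):
    logits = model(inp)  # linear runs in the lower-precision dtype
    with torch.autocast(device_type=device, enabled=False):
        # Autocast is off here, so cast explicitly; otherwise ops simply
        # run in whatever dtype their inputs already have.
        extra = logits.float().pow(2).mean()  # hypothetical extra loss term

# Main loss computed outside autocast, on an explicitly upcast input.
main = F.cross_entropy(logits.float(), target)
loss = main + 1e-3 * extra
loss.backward()

# Inspect the dtypes at each stage rather than assuming them.
print(logits.dtype, extra.dtype, main.dtype, loss.dtype)
```

The parameters stay float32 throughout, so their gradients are float32 even though the matmul itself ran in reduced precision.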

Is there a reason this check in GradScaler.step() is enforced?

if optimizer_state["stage"] is OptState.STEPPED:
    raise RuntimeError("step() has already been called since the last update().")

It seems like stepping the same optimizer twice after two independent backward passes wouldn’t be an issue.
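For reference, one way to step the same optimizer twice with a single scaler without hitting that error is to call `scaler.update()` after each `scaler.step()`: `update()` clears the per-optimizer state that the check inspects. This is a sketch under assumptions (a toy `Linear` model, SGD, a CPU fallback where the scaler is disabled), not an officially recommended pattern; note that the scale may also change between the two steps. Newer releases prefer `torch.amp.GradScaler`, but `torch.cuda.amp.GradScaler` is used here for broader version coverage.

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

torch.manual_seed(0)
model = torch.nn.Linear(16, 4).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
# Loss scaling matters for float16 on CUDA; elsewhere the scaler is
# disabled and step() just calls opt.step().
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

before = model.weight.detach().clone()
for _ in range(2):  # two independent backward passes, same optimizer
    inp = torch.randn(8, 16, device=device)
    target = torch.randint(0, 4, (8,), device=device)
    opt.zero_grad()
    with torch.autocast(device_type=device, dtype=amp_dtype):
        logits = model(inp)
    loss = F.cross_entropy(logits.float(), target)
    scaler.scale(loss).backward()
    scaler.step(opt)  # skips opt.step() if inf/NaN gradients were found
    scaler.update()   # resets per-optimizer state, so the next step() is allowed
```

Without the first `scaler.update()`, the second `scaler.step(opt)` on an enabled scaler would raise the `RuntimeError` quoted above.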

Is updating the _scale absolutely necessary algorithmically?
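For context, `update()` is where dynamic loss scaling happens. The pure-Python sketch below is a simplified re-implementation, not PyTorch's code, of the rule described in the `GradScaler` documentation, using its documented defaults (`init_scale=2.**16`, `growth_factor=2.0`, `backoff_factor=0.5`, `growth_interval=2000`). If the scale were never updated, a scale that caused an overflow would stay in place and later steps could keep being skipped, while a scale that is too small would never grow back.

```python
def update_scale(scale, growth_tracker, found_inf,
                 growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):
    """Simplified dynamic loss-scale update (a sketch, not the library code).

    Returns the new (scale, growth_tracker) pair.
    """
    if found_inf:
        # Overflow: the optimizer step was skipped, so shrink the scale
        # and restart the count of consecutive clean steps.
        return scale * backoff_factor, 0
    growth_tracker += 1
    if growth_tracker == growth_interval:
        # Enough consecutive clean steps: try a larger scale.
        return scale * growth_factor, 0
    return scale, growth_tracker


# Start at 2**16, overflow once, then run clean steps.
scale, tracker = 2.0 ** 16, 0
scale, tracker = update_scale(scale, tracker, found_inf=True)
# scale is now 2**15
for _ in range(2000):
    scale, tracker = update_scale(scale, tracker, found_inf=False)
# after 2000 clean steps the scale has grown back to 2**16
```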