Torch autocast's gradient

Hi guys, I’m trying to understand how torch.amp.autocast works. The following is a minimal code example:

import torch

def see():
    model = torch.nn.Linear(10, 5).cuda()
    # zero all parameters so the run is deterministic
    for p in model.parameters():
        p.data.fill_(0)
    X = torch.rand(3, 10).cuda()
    optimizer = torch.optim.SGD(model.parameters(), 0.001)
    scaler = torch.cuda.amp.GradScaler()
    with torch.cuda.amp.autocast():
        out = model(X)
        loss = out.mean()
        print(out.dtype, loss.dtype)
    scaler.scale(loss).backward()
    # output the gradient and its dtype (still multiplied by the loss scale here;
    # scaler.step() unscales it before the optimizer update)
    print(next(model.parameters()).grad, next(model.parameters()).grad.dtype)
    scaler.step(optimizer)
    scaler.update()

see()

I noticed that the gradient is actually single precision (FP32), which seems odd. According to the paper “Mixed Precision Training”, shouldn’t it be FP16?

That’s expected: the gradient dtype matches the parameter dtype, and autocast leaves the parameters in float32 during mixed-precision training.
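To make this concrete, here is a minimal sketch (an illustration, not code from the original post) that compares the parameter dtype, the autocast output dtype, and the dtype of the parameter's .grad. It assumes a fallback to CPU autocast with bfloat16 when no GPU is available, and it skips GradScaler because only the dtypes matter here.

```python
import torch

# Assumption: use float16 on CUDA, otherwise fall back to CPU autocast with bfloat16.
device = "cuda" if torch.cuda.is_available() else "cpu"
low_dtype = torch.float16 if device == "cuda" else torch.bfloat16

model = torch.nn.Linear(10, 5).to(device)   # parameters are created in float32
x = torch.rand(3, 10, device=device)

with torch.autocast(device_type=device, dtype=low_dtype):
    out = model(x)              # the linear op runs in low precision
    loss = out.float().mean()   # explicit cast so the reduction is in float32

loss.backward()

weight = model.weight
print("param dtype: ", weight.dtype)        # parameter storage is untouched by autocast
print("output dtype:", out.dtype)           # activation produced inside autocast
print("grad dtype:  ", weight.grad.dtype)   # .grad follows the parameter dtype
```

The point of the sketch is that autocast only casts inputs inside individual ops; it never converts the stored parameters, so the accumulated .grad keeps their float32 dtype.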


So is there an FP16 version of the gradient at any point during training, or is it FP32 all the time?