Gradient with Automatic Mixed Precision

I am reading the material about Automatic Mixed Precision (Automatic Mixed Precision — PyTorch Tutorials 2.1.1+cu121 documentation)

Here is some example code from the link:

import torch

batch_size = 100  # Try, for example, 128, 256, 513.
in_size = 4096
out_size = 4096
num_layers = 3
num_batches = 1
epochs = 3

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def make_model(in_size, out_size, num_layers):
    layers = []
    for _ in range(num_layers - 1):
        layers.append(torch.nn.Linear(in_size, in_size))
    layers.append(torch.nn.Linear(in_size, out_size))
    # Parameters are created in the default dtype (float32).
    return torch.nn.Sequential(*tuple(layers)).to(device)

# Creates data in default precision.
# The same data is used for both default and mixed precision trials below.
# You don't need to manually change inputs' ``dtype`` when enabling mixed precision.
data = [torch.randn(batch_size, in_size, device=device) for _ in range(num_batches)]
targets = [torch.randn(batch_size, out_size, device=device) for _ in range(num_batches)]
loss_fn = torch.nn.MSELoss().to(device)

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)

for epoch in range(1):  # 1 epoch, this section is for illustration only
    for input, target in zip(data, targets):
        # Runs the forward pass under ``autocast``.
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            # output is float16 because linear layers ``autocast`` to float16.

            loss = loss_fn(output, target)
            # loss is float32 because ``mse_loss`` layers ``autocast`` to float32.

        # Exits ``autocast`` before backward().
        # Backward passes under ``autocast`` are not recommended.
        # Backward ops run in the same ``dtype`` ``autocast`` chose for corresponding forward ops.
        loss.backward()

        # Print the dtype of each parameter's accumulated gradient.
        for each in net.parameters():
            if each.grad is not None:
                print(f"now grad type is {each.grad.dtype}")

        opt.step()
        opt.zero_grad()  # set_to_none=True here can modestly improve performance

The result shows that the gradients are float32:

now grad type is torch.float32
now grad type is torch.float32
now grad type is torch.float32
now grad type is torch.float32
now grad type is torch.float32
now grad type is torch.float32

The comment says

#Backward ops run in the same dtype autocast chose for corresponding forward ops

which I take to mean the backward ops run in float16. Yet the final gradient is stored in float32. If all the backward ops run in float16, why is the final gradient cast to float32 here?

The gradient is cast to the parameter's dtype. By default, all parameters are initialized in float32, so their .grad attributes are float32 as well.
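A minimal sketch of this behavior, assuming an explicit `.to(torch.float16)` cast is a fair stand-in for the cast autocast inserts before a float16 op (the tensor names `w`, `w16`, `y` are illustrative). The cast is recorded in the autograd graph, so its backward converts the float16 gradient back to the leaf's float32 dtype:

```python
import torch

# A float32 leaf tensor, standing in for a default-initialized parameter.
w = torch.ones(3, requires_grad=True)

# Cast to float16 and compute in float16, roughly what autocast does for ops
# on its float16 list. The cast itself is recorded in the autograd graph.
w16 = w.to(torch.float16)
y = (w16 * 3.0).sum()
print(y.dtype)  # torch.float16

y.backward()
# The backward of the cast turns the float16 gradient back into float32,
# so the gradient accumulated into w.grad has w's dtype.
print(w.grad.dtype)  # torch.float32
print(w.grad)        # tensor([3., 3., 3.])
```

The backward arithmetic itself (the multiply by 3) happens in float16; only the final step through the cast produces float32.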

Thanks for your reply!

So is there any other reason, apart from keeping the gradient dtype consistent with the parameter's dtype?

If the gradient were stored in fp16 and only cast to fp32 during the optimizer step, using the _single_tensor path to update the parameters one at a time, maybe we could save some GPU memory? I don't know whether there are CUDA ops that operate on fp16 and fp32 tensors directly. In any case, the gradient has to be multiplied by the learning rate, and in fp16 the product becomes 0 if abs(grad * learning_rate) is smaller than roughly 2**-24 (the smallest positive fp16 subnormal). So I agree the gradient needs to be cast to fp32 before it is finally applied to update the parameter weights. But if we update the parameter weights over several iterations, storing the gradients in fp16 could save some GPU memory.
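The underflow concern can be checked directly. This is a small illustration with made-up values (`lr = 1e-4` and a gradient entry of `1e-4` are assumptions chosen so the product lands below the fp16 range):

```python
import torch

lr = 1e-4
grad_fp16 = torch.tensor([1e-4, 1e-2], dtype=torch.float16)

# Scaling in float16: 1e-4 * 1e-4 = 1e-8 is below the smallest positive
# float16 subnormal (2**-24, about 5.96e-8), so the first entry rounds to 0.
update_fp16 = grad_fp16 * lr
print(update_fp16)

# Upcasting before scaling keeps the small update representable in float32.
update_fp32 = grad_fp16.float() * lr
print(update_fp32)
```

The second entry (about 1e-6) survives in float16 only as a subnormal with reduced precision, which is another reason to apply the learning rate in float32.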

Suppose we need 30 GB to store all fp32 gradients; in fp16 that would be 15 GB. If each step updates 1/5 of all parameters, which needs 6 GB of fp32 gradients, the peak memory allocated for gradients should be 15 GB + 6 GB = 21 GB over the 5 iterations it takes to update all parameters.
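A sketch of the idea under stated assumptions: `peak_grad_memory_gb` and `chunked_sgd_step` are hypothetical helpers, not PyTorch APIs. Because stock autograd keeps `.grad` in the parameter's dtype, the float16 gradients are held in a separate list here, and only plain SGD (`p -= lr * g`) is shown. The memory estimate counts gradient storage only and ignores temporaries, activations, and allocator overhead:

```python
import torch

def peak_grad_memory_gb(fp32_grad_gb, num_chunks):
    # fp16 copies of all gradients (half the fp32 size) stay resident,
    # plus a float32 upcast of one chunk at a time.
    return fp32_grad_gb / 2 + fp32_grad_gb / num_chunks

print(peak_grad_memory_gb(30, 5))  # 21.0

@torch.no_grad()
def chunked_sgd_step(params, grads_fp16, lr, num_chunks=5):
    """Plain SGD with float16 gradients kept in a separate list and
    upcast to float32 one chunk at a time (hypothetical helper)."""
    chunk_size = max(1, -(-len(params) // num_chunks))  # ceiling division
    for start in range(0, len(params), chunk_size):
        p_chunk = params[start:start + chunk_size]
        # Only this chunk's gradients exist in float32 at any one time.
        g_chunk = [g.float() for g in grads_fp16[start:start + chunk_size]]
        for p, g in zip(p_chunk, g_chunk):
            p.add_(g, alpha=-lr)  # learning rate applied in float32
        del g_chunk  # drop the float32 copies before the next chunk

# Usage: ten float32 "parameters" with float16 gradients.
params = [torch.ones(4) for _ in range(10)]
grads = [torch.full((4,), 2.0, dtype=torch.float16) for _ in range(10)]
chunked_sgd_step(params, grads, lr=0.5)
print(params[0])  # tensor([0., 0., 0., 0.])
```

Whether freed chunks actually lower the reported peak on a GPU also depends on the caching allocator, so this should be measured rather than assumed.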

By the way, when using an optimizer like Adam or AdamW with AMP, do we still need to store the first moment (the running average of the gradient) in fp32? All gradients are computed with fp16 ops, so even though they are finally cast and stored in fp32, the cast does not add any precision to their values. Since the first moment is just an average of the gradients, is there any other reason to keep it in fp32, or is it only to stay consistent with the fp32 gradient?
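One possible reason, offered as a sketch rather than a definitive answer: the exponential moving average adds small increments to an already stored value, and those increments can fall below fp16 resolution even when each gradient is representable. The toy values below (`beta1 = 0.9`, a gradient of about 1.004 against a moment of 1.0) are assumptions chosen to make the effect visible; the first part only checks that `torch.optim.Adam` allocates its `exp_avg` state in the parameter's dtype:

```python
import torch

# 1) Adam's state follows the parameter: a float32 parameter gets a
#    float32 first-moment buffer.
p = torch.nn.Parameter(torch.randn(3))
opt = torch.optim.Adam([p], lr=1e-3)
p.grad = torch.randn(3)
opt.step()
print(opt.state[p]["exp_avg"].dtype)  # torch.float32

# 2) First-moment EMA m = beta1 * m + (1 - beta1) * g in float16 vs float32.
beta1 = 0.9
g16 = torch.tensor(1.004, dtype=torch.float16)  # stored as 1.00390625
m16 = torch.tensor(1.0, dtype=torch.float16)
m32 = torch.tensor(1.0, dtype=torch.float32)
for _ in range(100):
    # In float16 each step would move m by about 3e-4, less than half the
    # float16 spacing near 1.0 (2**-10, about 9.8e-4), so it rounds away.
    m16 = beta1 * m16 + (1 - beta1) * g16
    m32 = beta1 * m32 + (1 - beta1) * g16.float()
print(m16.item(), m32.item())
```

In this toy case the float16 moment stays stuck at 1.0 while the float32 moment converges toward the gradient value, even though the gradient itself was an fp16 number.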