I am reading the material on Automatic Mixed Precision (Automatic Mixed Precision — PyTorch Tutorials 2.1.1+cu121 documentation). Here is some example code based on that page:
import torch

batch_size = 100  # Try, for example, 128, 256, 513.
in_size = 4096
out_size = 4096
num_layers = 3
num_batches = 1
epochs = 3

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)

def make_model(in_size, out_size, num_layers):
    layers = []
    for _ in range(num_layers - 1):
        layers.append(torch.nn.Linear(in_size, in_size))
        layers.append(torch.nn.ReLU())
    layers.append(torch.nn.Linear(in_size, out_size))
    return torch.nn.Sequential(*tuple(layers)).cuda()

# Creates data. The tutorial creates it in default precision; here the dtype is explicitly float16.
# The same data is used for both default and mixed precision trials below.
# You don't need to manually change inputs' ``dtype`` when enabling mixed precision.
data = [torch.randn(batch_size, in_size, dtype=torch.float16) for _ in range(num_batches)]
targets = [torch.randn(batch_size, out_size, dtype=torch.float16) for _ in range(num_batches)]

loss_fn = torch.nn.MSELoss().cuda()
net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)

for each in net.parameters():
    if each.grad is not None:
        print(each.grad.dtype)

for epoch in range(1):  # 1 epoch, this section is for illustration only
    for input, target in zip(data, targets):
        # Runs the forward pass under ``autocast``.
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            # output is float16 because linear layers ``autocast`` to float16.
            print(output.dtype)

            loss = loss_fn(output, target)
            # loss is float32 because ``mse_loss`` layers ``autocast`` to float32.
            print(loss.dtype)

        # Exits ``autocast`` before backward().
        # Backward passes under ``autocast`` are not recommended.
        # Backward ops run in the same ``dtype`` ``autocast`` chose for corresponding forward ops.
        loss.backward()
        for each in net.parameters():
            if each.grad is not None:
                print(f"now grad type is {each.grad.dtype}")
        opt.step()
        opt.zero_grad()  # set_to_none=True here can modestly improve performance
The output shows that the gradients are float32:
torch.float16
torch.float32
now grad type is torch.float32
now grad type is torch.float32
now grad type is torch.float32
now grad type is torch.float32
now grad type is torch.float32
now grad type is torch.float32
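One way to see which dtype the backward ops themselves use is to register a full backward hook on the Linear layers. Here is a quick sketch (not part of the tutorial; it reuses net, data, targets, loss_fn, and device from the code above, and the hook name report_grad_dtype is mine):

# Print the dtype of the gradient flowing into each Linear layer during backward.
def report_grad_dtype(module, grad_input, grad_output):
    print(f"{module.__class__.__name__} grad_output dtype: {grad_output[0].dtype}")

for layer in net:
    if isinstance(layer, torch.nn.Linear):
        layer.register_full_backward_hook(report_grad_dtype)

# One more forward/backward, same pattern as the loop above.
with torch.autocast(device_type=device, dtype=torch.float16):
    loss = loss_fn(net(data[0]), targets[0])
loss.backward()  # the hooks report torch.float16 here

for each in net.parameters():
    print(each.grad.dtype)  # ...but the stored .grad tensors are torch.float32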
The comment says "Backward ops run in the same dtype autocast chose for corresponding forward ops", which I think should be float16. But the final gradients are still stored in float32. So if all the backward ops run in float16, why are the final gradients cast to float32 here?
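For comparison, the same pattern can be reproduced without autocast by casting a float32 leaf tensor to float16 by hand. A minimal standalone sketch (not from the tutorial; it assumes a CUDA device like the code above):

import torch

# A float32 leaf, playing the role of a Linear weight.
w = torch.randn(4, 4, device='cuda', requires_grad=True)
x = torch.randn(8, 4, device='cuda', dtype=torch.float16)

# The matmul runs in float16 because w is explicitly cast to float16 first.
out = (x @ w.half()).sum()
out.backward()

print(w.grad.dtype)  # torch.float32

Here too the gradient flowing through the float16 matmul is float16, while the gradient that ends up stored on the float32 leaf is float32.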