What's the use of `scaled_grad_params` in this example of gradient penalty with scaled gradients?

Hi, I’m looking at the following example of working with gradient penalty with scaled gradients, and I don’t understand why we need to compute `scaled_grad_params` if in the end we only need `grad_params` to compute the penalty.

That is, can’t we instead directly write

grad_params = torch.autograd.grad(outputs=loss,
                                  inputs=model.parameters(),
                                  create_graph=True)

original example code:

scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)

        # Scales the loss for autograd.grad's backward pass, producing scaled_grad_params
        scaled_grad_params = torch.autograd.grad(outputs=scaler.scale(loss),
                                                 inputs=model.parameters(),
                                                 create_graph=True)

        # Creates unscaled grad_params before computing the penalty. scaled_grad_params are
        # not owned by any optimizer, so ordinary division is used instead of scaler.unscale_:
        inv_scale = 1./scaler.get_scale()
        grad_params = [p * inv_scale for p in scaled_grad_params]

        # Computes the penalty term and adds it to the loss
        with autocast(device_type='cuda', dtype=torch.float16):
            grad_norm = 0
            for grad in grad_params:
                grad_norm += grad.pow(2).sum()
            grad_norm = grad_norm.sqrt()
            loss = loss + grad_norm

        # Applies scaling to the backward call as usual.
        # Accumulates leaf gradients that are correctly scaled.
        scaler.scale(loss).backward()

        # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)

        # step() and update() proceed as usual.
        scaler.step(optimizer)
        scaler.update()

Gradient scaling is used to prevent underflows when float16 gradients are computed, as described in the docs.
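
To make the underflow point concrete, here is a rough standalone sketch (plain tensors only, not part of the tutorial): a gradient value that float32 represents fine flushes to zero once it has to pass through float16, while the scaled copy survives and can be divided back after it lands in float32.

import torch

# Standalone sketch of the underflow GradScaler guards against.
scale = 2.0 ** 16                          # GradScaler's default init_scale
g = torch.tensor(1e-8)                     # a "true" gradient value, fine in float32

print(g.to(torch.float16))                 # tensor(0., dtype=torch.float16) -> underflow
scaled = (g * scale).to(torch.float16)     # the scaled value fits comfortably in float16
print(scaled.float() / scale)              # ~1e-08 once unscaled back in float32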

Hi, I understand what gradient scaling is for, but I don’t understand why we explicitly compute `scaled_grad_params` if in the end we only need `grad_params` to compute the penalty.


I ran some tests, and directly computing `grad_params` gives the exact same output, so this example in the tutorial can be misleading. I can create a GitHub issue about this if you agree.

Setup:

import torch
import torch.nn as nn
import torch.optim as optim
from torch import autocast
from torch.cuda.amp import GradScaler

device = "cpu"
dtype = torch.bfloat16
torch.manual_seed(32)

# Toy model
class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

# Toy data
inputs = torch.randn(10, 10).to(device)
targets = torch.randn(10, 1).to(device)

# Toy loss function
loss_fn = nn.MSELoss()

# Toy model and optimizer
model = ToyModel().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)

scaler = GradScaler()

epochs = 5

With `scaled_grad_params`:

for epoch in range(epochs):
    for input, target in zip(inputs, targets):
        optimizer.zero_grad()
        with autocast(device_type=device, dtype=dtype):
            output = model(input)
            loss = loss_fn(output, target)

        # Scales the loss for autograd.grad's backward pass, producing scaled_grad_params
        scaled_grad_params = torch.autograd.grad(outputs=scaler.scale(loss),
                                                 inputs=model.parameters(),
                                                 create_graph=True)

        # Creates unscaled grad_params before computing the penalty. scaled_grad_params are
        # not owned by any optimizer, so ordinary division is used instead of scaler.unscale_:
        inv_scale = 1./scaler.get_scale()
        grad_params = [p * inv_scale for p in scaled_grad_params]

        # Computes the penalty term and adds it to the loss
        with autocast(device_type=device, dtype=dtype):
            grad_norm = 0
            for grad in grad_params:
                grad_norm += grad.pow(2).sum()
            grad_norm = grad_norm.sqrt()
            loss = loss + grad_norm

        print(f"epoch {epoch} - loss: {loss}")

        # Applies scaling to the backward call as usual.
        # Accumulates leaf gradients that are correctly scaled.
        scaler.scale(loss).backward()

        # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)

        # step() and update() proceed as usual.
        scaler.step(optimizer)
        scaler.update()

print(model(torch.zeros(10)))

Without `scaled_grad_params`:

for epoch in range(epochs):
    for input, target in zip(inputs, targets):
        optimizer.zero_grad()
        with autocast(device_type=device, dtype=dtype):
            output = model(input)
            loss = loss_fn(output, target)

        # Computes grad_params directly from the unscaled loss
        grad_params = torch.autograd.grad(outputs=loss,
                                          inputs=model.parameters(),
                                          create_graph=True)

        # Computes the penalty term and adds it to the loss
        with autocast(device_type=device, dtype=dtype):
            grad_norm = 0
            for grad in grad_params:
                grad_norm += grad.pow(2).sum()
            grad_norm = grad_norm.sqrt()
            loss = loss + grad_norm
        
        print(f"epoch {epoch} - loss: {loss}")

        # Applies scaling to the backward call as usual.
        # Accumulates leaf gradients that are correctly scaled.
        scaler.scale(loss).backward()

        # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)

        # step() and update() proceed as usual.
        scaler.step(optimizer)
        scaler.update()

print(model(torch.zeros(10)))

bfloat16 usage in amp does not need gradient scaling, but float16 does, to prevent underflows. If your values do not suffer from underflows, you won’t see a difference when no gradient scaling is used. The `scaled_grad_params` are in float32 and can be safely unscaled, so the example is not misleading.
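
As a quick check of the float32 claim, here is a minimal sketch (it assumes a CUDA device, mirroring the tutorial's float16 autocast rather than the CPU/bfloat16 experiment above): the forward matmul runs in float16, but the gradients autograd.grad returns for the float32 parameters come back as float32, which is why plain division by the scale is safe. The scaling itself protects the float16 intermediates of the backward pass.

import torch
import torch.nn as nn

# Minimal sketch: float16 autocast forward, float32 parameter gradients.
model = nn.Linear(10, 1).cuda()
x = torch.randn(4, 10, device='cuda')

with torch.autocast(device_type='cuda', dtype=torch.float16):
    out = model(x)                         # matmul runs in float16 under autocast
    loss = out.pow(2).mean()

grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
print(out.dtype)                           # torch.float16
print([g.dtype for g in grads])            # [torch.float32, torch.float32]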

Thank you for your prompt reply. Please ignore my mistaken experiment and comment. But I still don’t understand why we can’t just write

with autocast(device_type=device, dtype=dtype):
    output = model(input)
    loss = loss_fn(output, target)

grad_params = torch.autograd.grad(outputs=loss,
                                  inputs=model.parameters(),
                                  create_graph=True)

grad_params is also float32 here.