Hi, I'm looking at the following example of using a gradient penalty with scaled gradients, and I don't understand why we need to compute scaled_grad_params if, in the end, we only need grad_params to compute the penalty.
That is, can't we instead directly write:
grad_params = torch.autograd.grad(outputs=loss,
                                  inputs=model.parameters(),
                                  create_graph=True)
Original example code:
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)

        # Scales the loss for autograd.grad's backward pass, producing scaled_grad_params
        scaled_grad_params = torch.autograd.grad(outputs=scaler.scale(loss),
                                                 inputs=model.parameters(),
                                                 create_graph=True)

        # Creates unscaled grad_params before computing the penalty. scaled_grad_params are
        # not owned by any optimizer, so ordinary division is used instead of scaler.unscale_:
        inv_scale = 1./scaler.get_scale()
        grad_params = [p * inv_scale for p in scaled_grad_params]

        # Computes the penalty term and adds it to the loss
        with autocast(device_type='cuda', dtype=torch.float16):
            grad_norm = 0
            for grad in grad_params:
                grad_norm += grad.pow(2).sum()
            grad_norm = grad_norm.sqrt()
            loss = loss + grad_norm

        # Applies scaling to the backward call as usual.
        # Accumulates leaf gradients that are correctly scaled.
        scaler.scale(loss).backward()

        # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)

        # step() and update() proceed as usual.
        scaler.step(optimizer)
        scaler.update()
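
For context, one plausible reason for the scaling step is float16 underflow: the backward pass run by torch.autograd.grad may produce intermediate gradient values too small for float16, and those would be flushed to zero unless the loss is scaled first. The sketch below is not part of the original example. It only mimics the effect with a plain dtype cast on the CPU; the gradient value 1e-8 and the scale factor 65536 are assumptions chosen for illustration.

```python
import torch

# Hypothetical tiny gradient value, chosen for illustration only.
tiny_grad = torch.tensor(1e-8, dtype=torch.float32)

# Assumed scale factor (2**16), used here only as an example value.
scale = 65536.0

# Casting directly to float16 flushes the value to zero (underflow).
unscaled_fp16 = tiny_grad.to(torch.float16)

# Scaling before the cast keeps the value representable in float16.
scaled_fp16 = (tiny_grad * scale).to(torch.float16)

# Unscaling in float32 recovers approximately the original value.
recovered = scaled_fp16.to(torch.float32) / scale

print("unscaled, cast to fp16:", unscaled_fp16.item())
print("scaled, cast to fp16:  ", scaled_fp16.item())
print("recovered in fp32:     ", recovered.item())
```

If this is indeed the motivation, calling torch.autograd.grad on the unscaled loss could silently lose small gradient contributions before the penalty is computed, whereas scaling first and dividing by the scale afterwards would preserve them.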