grads = torch.autograd.grad(loss, self.model.parameters(), create_graph=False)
Is there a function like torch.nn.utils.clip_grad_norm_
for clipping the gradients in this case?
You could reuse the internal implementation of clip_grad_norm_ (defined in torch/nn/utils/clip_grad.py). E.g., this should work:
import torch
import torchvision.models as models

# setup
x = torch.randn(2, 3, 224, 224)
model = models.resnet18().eval()
out = model(x)
loss = (out**2).mean()
loss.backward()

# create reference by clipping the .grad attributes in-place
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
grad_ref = [p.grad.detach().clone() for p in model.parameters()]

# autograd.grad approach
model.zero_grad()
out = model(x)
loss = (out**2).mean()
grads = torch.autograd.grad(loss, model.parameters(), create_graph=False)

# replicate the clipping logic of clip_grad_norm_ on the returned tuple
device = grads[0].device
norm_type = 2.0
max_norm = 1.0
total_norm = torch.norm(torch.stack([torch.norm(grad.detach(), norm_type).to(device) for grad in grads]), norm_type)
clip_coef = max_norm / (total_norm + 1e-6)
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for grad in grads:
    grad.detach().mul_(clip_coef_clamped.to(grad.device))

# compare; the differences should be (close to) zero
assert len(grad_ref) == len(grads)
for g_ref, g in zip(grad_ref, grads):
    print((g_ref - g).abs().max())
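If you want to reuse this in a training loop, you could wrap the clipping logic in a small helper. The clip_grads_ function and the optimizer wiring below are just a sketch of mine (not a torch.nn.utils API); it applies the same math as above to the tuple returned by torch.autograd.grad and then copies the clipped gradients into .grad so a standard optimizer can step:

import torch
import torchvision.models as models

def clip_grads_(grads, max_norm, norm_type=2.0, eps=1e-6):
    # Hypothetical helper (not part of torch.nn.utils): clips a tuple/list of
    # gradient tensors in-place and returns the total norm, mirroring the
    # math used by torch.nn.utils.clip_grad_norm_.
    device = grads[0].device
    total_norm = torch.norm(
        torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]),
        norm_type,
    )
    clip_coef = torch.clamp(max_norm / (total_norm + eps), max=1.0)
    for g in grads:
        g.detach().mul_(clip_coef.to(g.device))
    return total_norm

# usage sketch: clip the grads, write them into .grad, and let an optimizer step
model = models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
x = torch.randn(2, 3, 224, 224)
loss = (model(x) ** 2).mean()
grads = torch.autograd.grad(loss, model.parameters(), create_graph=False)
clip_grads_(grads, max_norm=1.0)
for p, g in zip(model.parameters(), grads):
    p.grad = g
optimizer.step()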