How to clip the gradient norm of gradients from torch.autograd.grad

grads = torch.autograd.grad(loss, self.model.parameters(), create_graph=False)

Is there a function like torch.nn.utils.clip_grad_norm_ for clipping the gradients in this case?

You could reuse the internal implementation of clip_grad_norm_ found here.
E.g. this should work:

import torch
import torchvision.models as models

# setup
x = torch.randn(2, 3, 224, 224)
model = models.resnet18().eval()
out = model(x)
loss = (out**2).mean()
loss.backward()

# create reference
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

grad_ref = [p.grad.detach().clone() for p in model.parameters()]

# autograd.grad approach
model.zero_grad()
out = model(x)
loss = (out**2).mean()

grads = torch.autograd.grad(loss, model.parameters(), create_graph=False)

# compute the total 2-norm over all gradients, as clip_grad_norm_ does internally
device = grads[0].device
norm_type = 2.0
max_norm = 1.0
total_norm = torch.norm(torch.stack([torch.norm(grad.detach(), norm_type).to(device) for grad in grads]), norm_type)
# scale the gradients in place if the total norm exceeds max_norm
clip_coef = max_norm / (total_norm + 1e-6)
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for grad in grads:
    grad.detach().mul_(clip_coef_clamped.to(grad.device))

# compare
assert len(grad_ref) == len(grads)
for g_ref, g in zip(grad_ref, grads):
    print((g_ref - g).abs().max())
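
For convenience, the same logic can be wrapped in a small helper that takes the tuple returned by torch.autograd.grad, clips the tensors in place, and returns the total norm, mirroring the interface of clip_grad_norm_. This is just a sketch based on the snippet above; the name clip_grads_ is made up and not part of the PyTorch API:

import torch

def clip_grads_(grads, max_norm, norm_type=2.0):
    # grads: tuple of tensors as returned by torch.autograd.grad;
    # the tensors are modified in place and the total norm is returned
    device = grads[0].device
    total_norm = torch.norm(
        torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]),
        norm_type)
    clip_coef = torch.clamp(max_norm / (total_norm + 1e-6), max=1.0)
    for g in grads:
        g.detach().mul_(clip_coef.to(g.device))
    return total_norm

# usage
grads = torch.autograd.grad(loss, model.parameters(), create_graph=False)
total_norm = clip_grads_(grads, max_norm=1.0)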