I have a network that is dealing with some exploding gradients. I want to employ gradient clipping using torch.nn.utils.clip_grad_norm_, but I would like to get an idea of what the gradient norms actually are before I pick a clipping threshold, rather than guessing at random.
total_norm = 0.0
for p in model.parameters():
    # accumulate squared per-parameter L2 norms, then take the square root
    param_norm = p.grad.detach().norm(2)
    total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
Another tensor-style way is:
norm_type = 2.0  # the p-norm to compute; 2.0 gives the usual L2 norm
parameters = [p for p in model.parameters() if p.grad is not None and p.requires_grad]
if len(parameters) == 0:
    total_norm = 0.0
else:
    device = parameters[0].grad.device
    # the norm of the vector of per-parameter norms equals the norm over all gradients
    total_norm = torch.norm(
        torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
        norm_type,
    ).item()
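Since norm_type is a parameter, the same pattern covers other norms too. As a small variation of my own (reusing parameters and device from the snippet above), setting norm_type = float('inf') gives the largest absolute gradient entry across all parameters, which is handy for spotting a single exploding weight:

norm_type = float('inf')
# inf-norm variant: the max of per-parameter maxima is the global max |grad|
total_norm = torch.norm(
    torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
    norm_type,
).item()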
The first snippet doesn’t work in my case, because the grad can be None, so I use:
# wrapped in a helper since the snippet ends with a return; the name is arbitrary
def grad_norm(model):
    total_norm = 0.0
    parameters = [p for p in model.parameters() if p.grad is not None and p.requires_grad]
    for p in parameters:
        param_norm = p.grad.detach().norm(2)
        total_norm += param_norm.item() ** 2
    total_norm = total_norm ** 0.5
    return total_norm
This works: I printed out the gradient norm and then clipped using a fairly restrictive clipping threshold.
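For context, here is a minimal sketch of where such a check fits in a training loop; the model, optimizer, criterion, and data below are placeholders, grad_norm is the helper above, and max_norm=1.0 is just an example value:

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.MSELoss()

for step in range(100):
    inputs = torch.randn(32, 10)
    targets = torch.randn(32, 1)

    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()

    # inspect the gradient norm after backward() but before the optimizer step
    print(f"step {step}: grad norm = {grad_norm(model):.4f}")

    # once the typical range is known, clip just above it
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()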
I might be missing something, but computing a norm per parameter and then a norm over those norms isn’t necessary. It’s about 3x faster to concat all the grads into a single tensor and then calculate the norm once:
# flatten every gradient, concatenate into one vector, take a single norm
grads = [
    param.grad.detach().flatten()
    for param in model.parameters()
    if param.grad is not None
]
norm = torch.cat(grads).norm()
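One last shortcut worth knowing: clip_grad_norm_ itself returns the total norm it computed before clipping, so you can log the norm and clip in a single call (max_norm=1.0 is again just an example value):

# clip_grad_norm_ returns the pre-clipping total norm,
# so measuring and clipping can happen in one call
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"grad norm before clipping: {float(total_norm):.4f}")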