Check the norm of gradients

I have a network that is dealing with some exploding gradients. I want to employ gradient clipping using torch.nn.utils.clip_grad_norm_, but I would like to have an idea of what the gradient norms are before I randomly guess where to clip.

How can I view the norms that are to be clipped?
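
For reference, torch.nn.utils.clip_grad_norm_ returns the total norm it measured before clipping, so one way to see the value is to pass a huge max_norm (so that nothing is actually clipped) and log the return value. A minimal sketch with a made-up model and data:

    import torch
    import torch.nn as nn

    # Dummy model and data, just to produce some gradients.
    model = nn.Linear(10, 1)
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()

    # With a very large max_norm, clip_grad_norm_ only measures the total
    # gradient norm and returns it without changing the gradients.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e9)
    print("grad norm before clipping:", float(total_norm))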

Actually it seems the answer is in the code I linked to:

For a 2-norm:

    total_norm = 0
    for p in model.parameters():
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
    total_norm = total_norm ** (1. / 2)
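
For context, here is where that computation would typically sit in a training step; the model, data, and optimizer below are placeholders I made up, not part of the snippet above. The norm is computed right after backward() and before the optimizer step:

    import torch
    import torch.nn as nn

    # Placeholder model, optimizer, and batch.
    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    x, y = torch.randn(8, 10), torch.randn(8, 1)

    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()

    # Same accumulation as above: sum of squared per-parameter 2-norms, then sqrt.
    total_norm = 0.0
    for p in model.parameters():
        total_norm += p.grad.data.norm(2).item() ** 2
    total_norm = total_norm ** 0.5
    print("grad 2-norm:", total_norm)

    optimizer.step()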

It is safer to detach the grad:

    total_norm = 0.0
    for p in model.parameters():
        param_norm = p.grad.detach().norm(2)  # detach so the norm computation isn't tracked by autograd
        total_norm += param_norm.item() ** 2
    total_norm = total_norm ** 0.5

Another tensor-style way is:

    import torch

    norm_type = 2.0
    parameters = [p for p in model.parameters() if p.grad is not None and p.requires_grad]
    if len(parameters) == 0:
        total_norm = 0.0
    else:
        device = parameters[0].grad.device
        total_norm = torch.norm(
            torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
            norm_type,
        ).item()
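
Wrapped up as a reusable helper, that could look like the sketch below; the function name and the norm_type default are my own choices, not from the snippet above:

    import torch

    def total_grad_norm(model, norm_type=2.0):
        # Hypothetical helper: norm of the per-parameter gradient norms,
        # mirroring the tensor-style snippet above.
        parameters = [p for p in model.parameters() if p.grad is not None and p.requires_grad]
        if len(parameters) == 0:
            return 0.0
        device = parameters[0].grad.device
        return torch.norm(
            torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
            norm_type,
        ).item()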

In the tensor-style way: why do you take the norm of the parameter norms?

That doesn’t work in my case, because the grad can be None, so I use:

    def grad_norm(model):  # the function name here is arbitrary
        parameters = [p for p in model.parameters() if p.grad is not None and p.requires_grad]
        total_norm = 0.0
        for p in parameters:
            param_norm = p.grad.detach().norm(2)
            total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5
        return total_norm

This works: I printed out the grad norm and then clipped using a restrictive clipping threshold.
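
Spelled out end to end, that "print the norm, then clip" workflow might look like the sketch below; the model, data, and max_norm value are placeholders, not numbers from the post above:

    import torch
    import torch.nn as nn

    # Placeholder model and data just to make the sketch runnable.
    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()

    # Inspect the norm first ...
    pre_clip = sum(p.grad.detach().norm(2).item() ** 2
                   for p in model.parameters() if p.grad is not None) ** 0.5
    print("grad norm before clipping:", pre_clip)

    # ... then clip with a threshold chosen from what was printed (1.0 is just an example).
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()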

Hi milan, I don’t fully understand your question. The loop way also computes the norm of the parameter gradient norms, right?

Yes, exactly. I didn’t have a question; it’s just that your solution doesn’t work in my case. I’m noting it here in case any further people run into the same problem =)
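
For anyone wondering about the "norm of norms" point: for the 2-norm, the square root of the sum of squared per-parameter norms is exactly the 2-norm of all gradients flattened into one vector. A small self-contained check with made-up tensors:

    import torch

    # Two fake "per-parameter" gradients.
    g1, g2 = torch.randn(3, 4), torch.randn(5)

    # Norm of the per-parameter norms ...
    norm_of_norms = torch.norm(torch.stack([g1.norm(2), g2.norm(2)]), 2)
    # ... equals the 2-norm of everything flattened into a single vector.
    flat_norm = torch.cat([g1.flatten(), g2.flatten()]).norm(2)

    print(torch.allclose(norm_of_norms, flat_norm))  # True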

I might be missing something, but computing a norm per parameter and then a norm of those norms isn’t necessary. It’s about 3x faster to concatenate all the grads into a single tensor and then calculate the norm once:

    grads = [
        param.grad.detach().flatten()
        for param in model.parameters()
        if param.grad is not None
    ]
    norm = torch.cat(grads).norm()
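
One caveat to keep in mind (my note, not the original poster's): torch.cat copies every gradient into one new buffer, so on very large models the stacked-norms version above may be easier on memory. As a small helper (the function name is made up), the concat approach could be:

    import torch

    def flat_grad_norm(model):
        # Hypothetical helper: flatten and concatenate all gradients,
        # then take a single 2-norm over the result.
        grads = [p.grad.detach().flatten() for p in model.parameters() if p.grad is not None]
        if not grads:
            return 0.0
        return torch.cat(grads).norm(2).item()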