How to check norm of gradients?

In GAN hacks and his NIPS 2016 talk, Soumith Chintala (@smth) suggests checking that the network gradients aren’t exploding:

check norms of gradients: if they are over 100 things are screwing up

How might I do that in PyTorch?


After backward, the gradient for each parameter is stored in param.grad, so you can use that to compute the norm.

After loss.backward(), you can check the norm of each parameter's gradient like this:

# Print the L2 norm of each parameter's gradient
for p in net.parameters():
    if p.grad is not None:
        print(p.grad.norm(2).item())
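To compare against a single threshold like 100, it is common to combine the per-parameter norms into one total gradient norm (the square root of the sum of squared per-parameter norms). A minimal runnable sketch; the two-layer model and random data here are placeholders for illustration, not from the thread:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy model and data, purely for illustration
net = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 1))
x = torch.randn(8, 10)
target = torch.randn(8, 1)

loss = nn.functional.mse_loss(net(x), target)
loss.backward()

# Total gradient norm: sqrt of the sum of squared per-parameter L2 norms
total_norm = torch.norm(
    torch.stack([p.grad.norm(2) for p in net.parameters() if p.grad is not None])
)
print(f"total grad norm: {total_norm.item():.4f}")
```

Note that torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm) returns this same total norm, so you can log it and clip exploding gradients in one call.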