In GAN hacks and his NIPS 2016 talk, Soumith Chintala (@smth) suggests checking that the network gradients aren't exploding:
check norms of gradients: if they are over 100 things are screwing up
How might I do that in PyTorch?
The gradient for each parameter is stored in param.grad after calling backward(), so you can use that to compute its norm.
After loss.backward(), you can check the norm of each parameter's gradient like this:

for p in net.parameters():
    if p.grad is not None:
        print(p.grad.norm(2).item())
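
If you want a single number for the whole network rather than one per parameter, you can combine the per-parameter norms into a total L2 norm (this is the same quantity torch.nn.utils.clip_grad_norm_ reports). A minimal sketch, assuming net is your model and loss.backward() has already been called:

# accumulate squared L2 norms across all parameters that received gradients
total_sq = 0.0
for p in net.parameters():
    if p.grad is not None:
        total_sq += p.grad.norm(2).item() ** 2

total_norm = total_sq ** 0.5
# per the GAN hacks rule of thumb, values over ~100 suggest something is going wrong
print(total_norm)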