I am trying to debug a program whose gradients go nan during backprop. When I use hooks, how do I tell which function produced the nan gradients? For example, in the snippet below, the assert inside the hook registered at line 2 is triggered. Does that mean that line 4, x_skip = self.conv3dCaps4_nj_8(x), is the operation that produced the nan gradients during backprop? My reasoning: the hook on x is registered at line 2, and the gradient of x is computed when backpropagating through line 4, which consumed x.
1 x = self.conv2dCaps4_nj_8_strd_2(x)
2 x.register_hook(assert_enc5) # assert triggered here
3 assert not torch.isnan(x).any()
4 x_skip = self.conv3dCaps4_nj_8(x)
5 x_skip.register_hook(assert_enc4)
6 assert not torch.isnan(x_skip).any()
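For context, here is a minimal, self-contained sketch of the hook pattern I am asking about. The `make_nan_hook` helper is hypothetical (my `assert_enc*` hooks are assumed to be simple nan checks like this), and the toy graph is chosen so the forward pass is finite but the backward pass produces a nan:

```python
import torch

def make_nan_hook(name, nan_sources):
    # Hypothetical helper: the hook fires when the gradient w.r.t. the
    # hooked tensor is computed, i.e. AFTER backprop has passed through
    # every later op that consumed that tensor.
    def hook(grad):
        if torch.isnan(grad).any():
            nan_sources.append(name)
    return hook

nan_sources = []

# Toy graph: forward is finite, but d/dx sqrt(x) is inf at x = 0, and
# inf * 0 (the incoming gradient from z = 0 * y) gives nan.
x = torch.tensor([0.0], requires_grad=True)
x.register_hook(make_nan_hook("x", nan_sources))
y = x.sqrt()
y.register_hook(make_nan_hook("y", nan_sources))
z = 0.0 * y
z.backward()

print(nan_sources)
```

Here the hook on y sees a finite gradient (0.0), while the hook on x sees nan, so `nan_sources` ends up as `['x']`. By the reasoning in my question, that would point at `y = x.sqrt()`, the op that consumed x, as the source of the nan, which is the same situation as my hook at line 2 firing because of line 4.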