Hi, I have a complex model whose loss becomes 'nan' after several batches with lr=0.001; only a very small lr = 0.000001 lets it run a full epoch. I suspect this may be caused by gradient explosion. But even when I set `torch.nn.utils.clip_grad_norm(model.parameters(), 0.000001)`, the loss and some weights still become 'nan'. Can somebody tell me 1) how to find out which variable is the first one that becomes 'nan' and then causes the others to become 'nan', and 2) how to solve this kind of gradient explosion?
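For reference, a minimal sketch of gradient clipping under typical settings. The tiny `nn.Linear` model here is just a placeholder for the poster's model; note that the in-place variant is `clip_grad_norm_` (trailing underscore), and that `max_norm` bounds the *total* gradient norm, so a value like 1.0 is usually a more reasonable choice than 1e-6:

```
import torch
import torch.nn as nn

# Placeholder model standing in for the poster's "complex model".
model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.001)

x = torch.randn(8, 4)
y = torch.randn(8, 1)

loss = nn.functional.mse_loss(model(x), y)
loss.backward()

# Clip the total gradient norm to a sane bound before the optimizer step.
# clip_grad_norm_ returns the norm *before* clipping, useful for logging.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

opt.step()
```

Logging `total_norm` each step is one way to confirm whether gradients are actually exploding before NaNs appear.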


You could exploit the fact that nan != nan to identify the tensor that first contains a nan. To illustrate what I mean by nan != nan:

```
>>> import torch
>>> import numpy as np
>>> t = torch.tensor([1., 2., np.nan])
>>> t != t
tensor([ 0, 0, 1], dtype=torch.uint8)
```

Then, for each tensor in the graph, you can check in your code whether it contains a NaN via `if (t != t).any():`. It's probably best to use a debugger for that, e.g., raise an error when `(t != t).any()` evaluates to True and then inspect all the current variables.
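The check above can be wrapped in a small helper to locate the first offending tensor by name. `find_nan` is a hypothetical name, and applying it to `model.named_parameters()` in the usage comment is just one assumed place to call it (newer PyTorch versions also provide `torch.isnan` for the same test):

```
import torch

def find_nan(named_tensors):
    """Return the name of the first tensor containing a NaN, or None."""
    for name, t in named_tensors:
        if (t != t).any():  # exploits nan != nan to flag NaN entries
            return name
    return None

# Hypothetical usage inside a training loop, checking the parameters:
# bad = find_nan(model.named_parameters())
# assert bad is None, f"NaN first appeared in {bad}"
```

Calling such a check after each forward/backward pass narrows down which batch, and which tensor, first goes bad.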


Thanks for your advice! Using `(t != t).any()` together with `assert`, I finally found it.