How to monitor grads in a self-made function (composed of plain torch.Tensor operations) during the backward() process?

I am having trouble debugging NaNs.

The easiest way to debug this is to monitor the gradients of the model parameters (for example, using hooks).

However, all parameters become NaN at the same time, so it is difficult to identify the cause with that usual method.
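
For example, something like this is what I mean by monitoring the parameter grads (the model here is just a placeholder to illustrate the idea):

import torch
import torch.nn as nn

# placeholder model, just for illustration
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

def make_param_hook(name):
    def hook(grad):
        # report which parameter's grad contains NaN
        if torch.isnan(grad).any():
            print(f'NaN in grad of parameter {name}')
    return hook

# register a hook on every parameter so we can see whose grad turns NaN
for name, p in model.named_parameters():
    p.register_hook(make_param_hook(name))

model(torch.randn(2, 4)).sum().backward()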

In my model, the loss is calculated by applying a self-made function, composed of several plain torch.Tensor operations, to the output of a model defined with nn modules.

For the update, I perform backpropagation with backward() via autograd.

I suspect that the NaN occurs during the backward pass and want to identify the cause.
(NaN does not occur during normal inference.)

So, I would like to know how to monitor the propagation of grads through my own function composed of several torch.Tensor operations (i.e., print each grad value).

As you said, you could register backward hooks which print a message if a NaN is detected.

The hook body could look like this:

import torch
import numpy as np

# example tensor containing a NaN, just to illustrate the check (NaN != NaN)
grad_output = torch.tensor([1.0, np.nan, 2.0])
...
# signature for a module backward hook: (module, grad_input, grad_output)
def printgradnan(self, grad_input, grad_output):
    if (grad_output != grad_output).any():
        print('NaN detected!')
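
For completeness, attaching such a hook to every submodule could look roughly like this (a minimal sketch with a placeholder model; note that in a module backward hook grad_output is a tuple of tensors, so the check loops over it):

import torch
import torch.nn as nn

def printgradnan(module, grad_input, grad_output):
    # grad_output is a tuple of tensors in module backward hooks
    for g in grad_output:
        if g is not None and torch.isnan(g).any():
            print(f'NaN detected in {module.__class__.__name__}')

# placeholder model, just for illustration
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
for m in model.modules():
    m.register_full_backward_hook(printgradnan)

model(torch.randn(2, 4)).sum().backward()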

Thank you for replying!
I thought this register_backward_hook function was only for nn modules and could not be used on a plain tensor (with requires_grad=True).

I solved this problem by using the register_hook function from this discussion.

I implemented it like the following:

import torch

def save_grad(name):
    # returns a hook that prints the grad flowing into the named tensor
    def hook(grad):
        print(name, ":", grad)
    return hook

a = torch.ones(3, requires_grad=True)
b = 3 * a
c = b * a
loss = c.sum()
# register the hook on the intermediate tensor before calling backward()
b.register_hook(save_grad('b'))
loss.backward()

I checked for NaNs using your NaN detection code on each printed grad value.
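
For reference, combining the two ideas (the tensor hook plus the NaN check) looks roughly like this in my setup (the function name check_grad and the toy values are just for illustration):

import torch

def check_grad(name):
    def hook(grad):
        print(name, ":", grad)
        # same NaN check as above, applied to the intermediate grad
        if torch.isnan(grad).any():
            print(f'NaN detected in grad of {name}')
    return hook

a = torch.ones(3, requires_grad=True)
b = 3 * a
c = b * a
loss = c.sum()
b.register_hook(check_grad('b'))
c.register_hook(check_grad('c'))
loss.backward()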