How to pinpoint NaN grads?

Hello there,

I’m using an optimizer (AdamW) to optimize the positions of some points in an N-dimensional space, based on some grouping patterns:

At each iteration:

  • A subset of those points is selected, and the selected points should move closer together.
  • At the same time, the unselected points should move farther away.
  • I’m using a loss that resembles the contrastive loss from CLIP.

Here is an example of the performed operations:

        bs = selected.size(0)

        # selected: (bs, n_points); negative entries mark unselected points
        idxs = (selected >= 0)
        idxs_on = idxs
        idxs_off = ~idxs

        cnt = idxs_on.sum(dim=1)
        if (cnt == 0).any():
            raise ValueError('cnt has 0')

        # per-point weights, zeroed outside each group and normalized to sum to 1 per row
        vals_on = selected * idxs_on
        vals_off = -(selected * idxs_off)

        vals_on = (vals_on / vals_on.sum(dim=1).unsqueeze(1)).unsqueeze(2)
        vals_off = (vals_off / vals_off.sum(dim=1).unsqueeze(1)).unsqueeze(2)

        # broadcast the shared positions to the batch: (bs, n_points, n_dims)
        batch_positions = positions.expand(bs, -1, -1)

        # weighted centers of the selected / unselected points
        center_on = (batch_positions * vals_on).sum(dim=1).unsqueeze(1).detach()
        center_off = (batch_positions * vals_off).sum(dim=1).unsqueeze(1)

        center_dist = (batch_positions - center_on)

        if torch.isnan(center_dist).any():
            raise ValueError('center_dist has nan')

        # minimize
        dist_on = (center_dist * idxs_on.unsqueeze(2)).pow(2).sum(dim=1).sqrt()
        # maximize
        dist_off = (center_dist * idxs_off.unsqueeze(2)).pow(2).sum(dim=1).sqrt()
        # maximize
        dist_center = (center_on - center_off).pow(2).sum(dim=1).sqrt()

        dist_on_loss = dist_on.sum(dim=1)
        dist_off_loss = torch.log(1 + dist_off).sum(dim=1)
        dist_center_loss = dist_center.sum(dim=1)

        loss = dist_on_loss - dist_off_loss - dist_center_loss

        # reduce to a scalar so backward() can be called directly
        loss.mean().backward()

(I say an example because the real code operates on several batches of points and several different groups, but this is the essence of it.)

After some iterations my gradients go to NaN.

I’ve tried to:

  • set a very small learning rate (1e-10)
  • play with the batch size
  • monitor the forward pass and look for Inf, NaN and zeros, but the forward pass looks ok (roughly along the lines of the sketch below)
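
The forward-pass check was something along these lines (just a rough sketch; the check_finite helper and the specific tensors to watch are illustrative, not my exact code):

    import torch

    def check_finite(name, t, warn_zero=True):
        # illustrative helper: flag NaN / Inf (and optionally exact zeros) in a forward tensor
        if torch.isnan(t).any():
            raise ValueError(f'{name} contains NaN')
        if torch.isinf(t).any():
            raise ValueError(f'{name} contains Inf')
        if warn_zero and (t == 0).any():
            print(f'warning: {name} contains exact zeros')

    # e.g. after computing the intermediates:
    # check_finite('vals_on', vals_on)
    # check_finite('center_dist', center_dist)
    # check_finite('dist_off', dist_off)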

I added

torch.autograd.set_detect_anomaly(True)

and it gives me

RuntimeError: Function 'PowBackward0' returned nan values in its 0th output.

Unfortunately, it always points at the loss.backward() line.

Is there a way to pinpoint exactly in which operation the gradients go to NaN during the backward pass? Is there a way to place a breakpoint there?
Any input is welcome.

Thanks <3

It turned out that, during the forward pass, the inputs to my pow(2) calls contained a bunch of near-zero values.

While 0^2 just gives 0, near-zero values lead to all sorts of numerical instabilities in the pow(2).sum().sqrt() chain: once the sum of squares underflows to (or is exactly) 0, the backward of sqrt is Inf, and multiplying that Inf by the local gradient of pow(2) (which is 2x ≈ 0) produces NaN — exactly the PowBackward0 NaN that anomaly detection reported.
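
A minimal sketch that reproduces the effect (the eps value is just an illustrative choice):

    import torch

    # all components exactly zero, as can happen after masking or underflow
    x = torch.zeros(3, requires_grad=True)

    dist = x.pow(2).sum().sqrt()      # forward pass is a clean 0.0
    dist.backward()
    print(x.grad)                     # tensor([nan, nan, nan])

    # common workaround: keep the sqrt input strictly positive
    x.grad = None
    eps = 1e-8
    dist = (x.pow(2).sum() + eps).sqrt()
    dist.backward()
    print(x.grad)                     # finite (all zeros here)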

Still, I really would like to find a way to debug NaN grads more effectively.
No other way?

Is there a way to pinpoint exactly in which operation the gradients go to NaN during the backward pass?

If you scroll up, the anomaly detection message should show you a stack trace pointing at the forward operation that corresponds to the problematic PowBackward0.
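
If you want to narrow it down further (and actually place a breakpoint), you can also register backward hooks on the intermediate tensors you suspect; a minimal sketch, with an illustrative helper name:

    import torch

    def nan_hook(name):
        # the hook receives the gradient flowing into the tensor during backward
        def hook(grad):
            if not torch.isfinite(grad).all():
                print(f'non-finite gradient flowing into {name}')
                # breakpoint()  # uncomment to stop the backward pass right here
        return hook

    # register on whichever intermediates you want to watch, e.g.:
    # dist_on.register_hook(nan_hook('dist_on'))
    # dist_off.register_hook(nan_hook('dist_off'))
    # center_dist.register_hook(nan_hook('center_dist'))

The first watched tensor whose hook fires shows you where the Inf/NaN enters the graph during backward.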
