Hello there,
I’m using an optimizer (AdamW) to optimize the positions of some points in an N-dimensional space, based on some grouping patterns.
At each iteration:
- A subset of the points is selected; these points should move closer together.
- At the same time, the unselected points should move farther away.
- I’m using a loss that resembles the contrastive learning loss from CLIP.
Here is an example of the performed operations:
```python
# Assumed shapes (simplified): selected is (bs, num_points), with values >= 0
# marking the selected points; positions is (num_points, dim) and requires grad.
bs = selected.size(0)
idxs = (selected >= 0)
idxs_on = idxs
idxs_off = ~idxs

cnt = idxs_on.sum(dim=1)
if (cnt == 0).any():
    raise ValueError('cnt has 0')

# Normalized weights for the selected / unselected points.
vals_on = selected * idxs_on
vals_off = -(selected * idxs_off)
vals_on = (vals_on / vals_on.sum(dim=1).unsqueeze(1)).unsqueeze(2)
vals_off = (vals_off / vals_off.sum(dim=1).unsqueeze(1)).unsqueeze(2)

# Weighted centers of the selected (detached) and unselected points.
batch_positions = positions.expand(bs, -1, -1)
center_on = (batch_positions * vals_on).sum(dim=1).unsqueeze(1).detach()
center_off = (batch_positions * vals_off).sum(dim=1).unsqueeze(1)
center_dist = batch_positions - center_on
if torch.isnan(center_dist).any():
    raise ValueError('center_dist has nan')

# minimize
dist_on = (center_dist * idxs_on.unsqueeze(2)).pow(2).sum(dim=1).sqrt()
# maximize
dist_off = (center_dist * idxs_off.unsqueeze(2)).pow(2).sum(dim=1).sqrt()
# maximize
dist_center = (center_on - center_off).pow(2).sum(dim=1).sqrt()

dist_on_loss = dist_on.sum(dim=1)
dist_off_loss = torch.log(1 + dist_off).sum(dim=1)
dist_center_loss = dist_center.sum(dim=1)
loss = dist_on_loss - dist_off_loss - dist_center_loss
loss.backward()
```
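In case it helps, here is a small self-contained version of the same pattern with random placeholder data (the shapes, the seed, and the `.mean()` reduction at the end are just for illustration, not my real setup):

```python
import torch

torch.manual_seed(0)

num_points, dim, bs = 32, 8, 4

# Placeholder data: positions are the learnable parameters,
# selected encodes per-sample group membership (negative = unselected).
positions = torch.randn(num_points, dim, requires_grad=True)
selected = torch.rand(bs, num_points) - 0.5

optimizer = torch.optim.AdamW([positions], lr=1e-3)

for step in range(200):
    optimizer.zero_grad()

    idxs_on = (selected >= 0)
    idxs_off = ~idxs_on

    vals_on = selected * idxs_on
    vals_off = -(selected * idxs_off)
    vals_on = (vals_on / vals_on.sum(dim=1).unsqueeze(1)).unsqueeze(2)
    vals_off = (vals_off / vals_off.sum(dim=1).unsqueeze(1)).unsqueeze(2)

    batch_positions = positions.expand(bs, -1, -1)
    center_on = (batch_positions * vals_on).sum(dim=1).unsqueeze(1).detach()
    center_off = (batch_positions * vals_off).sum(dim=1).unsqueeze(1)
    center_dist = batch_positions - center_on

    dist_on = (center_dist * idxs_on.unsqueeze(2)).pow(2).sum(dim=1).sqrt()
    dist_off = (center_dist * idxs_off.unsqueeze(2)).pow(2).sum(dim=1).sqrt()
    dist_center = (center_on - center_off).pow(2).sum(dim=1).sqrt()

    # Reduce to a scalar here so backward() can be called directly.
    loss = (dist_on.sum(dim=1)
            - torch.log(1 + dist_off).sum(dim=1)
            - dist_center.sum(dim=1)).mean()
    loss.backward()

    if torch.isnan(positions.grad).any():
        print(f'NaN gradients at step {step}, loss = {loss.item()}')
        break

    optimizer.step()
```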
(I say an example because the real code operates on several batches of points and several different groups, but this is the essence of it.)
After some iterations my gradients go to NaN.
I’ve tried to:
- set a very small learning rate (1e-10)
- play with the batch size
- monitor the forward pass and look for Inf, NaN, and zeros (roughly as in the sketch below)
But the forward pass looks ok.
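The monitoring is basically a helper like this (simplified sketch; the tensor names are the ones from my snippet above):

```python
def check(name, t):
    # Simple forward-pass sanity check: flag NaN, Inf, or all-zero tensors.
    if torch.isnan(t).any():
        print(f'{name} has NaN')
    if torch.isinf(t).any():
        print(f'{name} has Inf')
    if (t == 0).all():
        print(f'{name} is all zeros')

for name, t in [('vals_on', vals_on), ('vals_off', vals_off),
                ('center_on', center_on), ('center_off', center_off),
                ('dist_on', dist_on), ('dist_off', dist_off),
                ('dist_center', dist_center), ('loss', loss)]:
    check(name, t)
```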
I added

```python
torch.autograd.set_detect_anomaly(True)
```

and it gives me

```
RuntimeError: Function 'PowBackward0' returned nan values in its 0th output.
```

Unfortunately it always points to the loss.backward() line.
Is there a way to pinpoint exactly which operation produces the NaN gradients during the backward pass? Is there a way to place a breakpoint there?
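I thought about registering backward hooks on the intermediate tensors, something like the sketch below (names from my snippet above), but I’m not sure this is the right approach, and it doesn’t give me a breakpoint inside the backward op itself:

```python
# Idea (not sure it's the right tool): register backward hooks on intermediate
# tensors to see which gradient is the first to contain NaN.
def nan_hook(name):
    def hook(grad):
        if torch.isnan(grad).any():
            print(f'NaN gradient flowing into {name}')
    return hook

for name, t in [('center_dist', center_dist), ('dist_on', dist_on),
                ('dist_off', dist_off), ('dist_center', dist_center)]:
    t.register_hook(nan_hook(name))
```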
Any input is welcome.
Thanks <3