NaN loss and grad encountered at first epoch

I'm encountering gradient overflow, and the model's behavior is really weird.

At about 1600 steps, the masked language modeling loss became NaN, and after a few more steps everything crashed down to NaN. At first I thought it was a trivial coding problem, but after a week of debugging I still can't figure out how this occurs.

After enabling torch.autograd.detect_anomaly, it reports “Function ‘LogBackward0’ returned nan values in its 0th output.” I've already added a large eps=1e-3 to torch.log, so this really shouldn't happen, which is weird.
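
Roughly, the guard looks like this (just a sketch; eps=1e-3 is the value mentioned above):

import torch

eps = 1e-3  # large eps so the log argument stays well inside fp16's range

def safe_log(x):
    return torch.log(x + eps)

torch.autograd.set_detect_anomaly(True)  # this is what flagged LogBackward0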

Things I have tried:

  • Switching AMP from Apex to the official torch.cuda.amp;
  • A self-made cross-entropy loss with a larger eps to fit the fp16 dynamic range;
  • Training with a lower learning rate (from 1e-4 to 5e-5 to 1e-5) and also lower multiplier ratios for the new layers;
  • Narrowing down the range of the loss scaling in backprop (from 32768 to 256);
  • Gradient clipping (unscaled gradients clipped to 1.0; see the sketch after this list);
  • Checking the data for invalid inputs.
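
The clipping follows the usual AMP pattern of unscaling before clip_grad_norm_; a sketch of my training step is below (names like model, optimizer, loader and loss_fn are placeholders):

scaler = torch.cuda.amp.GradScaler()

for batch in loader:
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():
        loss = loss_fn(model(batch))
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                              # gradients are now in real (unscaled) units
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)                                  # skipped if any gradient is inf/NaN
    scaler.update()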

Interesting fact:

I found that the minimum scaling threshold affects when the model fails. Code snippet:

if scaler._scale < min_thresh:
    scaler._scale = torch.tensor(min_thresh).to(scaler._scale)

When min_thresh is set to 256, it crashes at about 600 steps; at 128, about 900 steps; at 64, about 1150 steps; at 32, about 1450 steps; and without the clamp, about 1650 steps (all with a 5e-5 learning rate).
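
For context, the clamp has to run after scaler.update(), which is where the scale gets shrunk; roughly (sketch only):

scaler.step(optimizer)
scaler.update()
if scaler._scale < min_thresh:      # _scale is a private attribute of GradScaler
    scaler._scale = torch.tensor(min_thresh).to(scaler._scale)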

I hope someone can give me some further advice :smiling_face_with_tear: Thanks a lot!

Note that it is expected to see invalid gradients in a few iterations when amp in float16 is used. If such an iteration is seen, the GradScaler will reduce its internal scaling factor and will skip the optimizer.step() call so that no parameters will be updated with these invalid gradients. It is not expected to see invalid outputs computed in the forward pass, and if you are seeing such behavior I would recommend trying to narrow down where exactly this invalid output is created.
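
As a sanity check, you can detect those skipped iterations by watching the scale (a small sketch, not required for correct training):

scale_before = scaler.get_scale()
scaler.step(optimizer)     # silently skipped if inf/NaN gradients were found
scaler.update()
if scaler.get_scale() < scale_before:
    print("step skipped, scale reduced to", scaler.get_scale())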

Thanks for your kind advice and sorry for my late update.

It is a dual-stream model, XLM-RoBERTa + BERT with no cross-attention. I used forward hooks to check the forward pass as well as the gradients and parameters.
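
The hooks just flag non-finite outputs per module, roughly like this (a sketch):

def nan_hook(module, inputs, output):
    # only check tensor outputs; some modules return tuples
    outs = output if isinstance(output, tuple) else (output,)
    for o in outs:
        if torch.is_tensor(o) and not torch.isfinite(o).all():
            print(f"non-finite output in {module.__class__.__name__}")

for name, module in model.named_modules():
    module.register_forward_hook(nan_hook)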

The parameters and the forward pass seem to be fine now, except for some infinite gradients detected at the contrastive-loss temperature. The model's gradients seem to gradually fail from shallow to deep layers: at first some NaN values were reported at layers like roberta_encoder.layer.0.query.weight and some embedding layers; after a couple of hundred steps, NaN values were caught in deeper layers like roberta_encoder.layer.3.dense.weight; finally the whole model's gradients became NaN. I can't figure out how this could happen, since the model works just fine in FP32 and a lot of effort has been made to keep every number inside fp16's range.
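
A simple scan over the gradients after backward is enough to reproduce the per-layer reports above (sketch):

for name, param in model.named_parameters():
    if param.grad is not None and not torch.isfinite(param.grad).all():
        print(f"non-finite grad in {name}")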

Any further ideas?
Thanks again!

No, I don't have further ideas, as it's still unclear whether the observed NaN gradients are expected and caused by a high loss scaling factor or whether something else causes them. Your follow-up just explains that you are still observing invalid gradients, which might be expected.

Oh, maybe I misunderstood your earlier advice :frowning:

I have tried not setting the minimum scaling factor, and the model just kept skipping steps and shrinking the scaling factor until an error was raised. I also tried setting init_scale to 2048.0 or smaller, but it doesn't seem to work. Is this what you advised me to do?
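
For reference, lowering the initial scale amounts to something like this (sketch; the other arguments are left at their defaults):

scaler = torch.cuda.amp.GradScaler(
    init_scale=2048.0,     # default is 65536.0
    growth_interval=2000,  # default: steps with finite grads before the scale grows again
)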

Thanks a lot.

I finally figured out the problem.

Putting the division by math.sqrt(self.head_dim) on the query instead of after the matmul of query and key fixes the bug.

q = self.split_heads(self.q_proj(x)) / math.sqrt(self.head_dim)
k = self.split_heads(self.k_proj(x))
v = self.split_heads(self.v_proj(x))

attn_score = torch.matmul(q, k.transpose(-1, -2))
attn_probs = self.softmax(attn_score)

instead of

q = self.split_heads(self.q_proj(x))
k = self.split_heads(self.k_proj(x))
v = self.split_heads(self.v_proj(x))

attn_score = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
attn_probs = self.softmax(attn_score)
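
A toy fp16 check with made-up values shows how differently the two orderings behave (the magnitudes here are arbitrary, just chosen to sit near fp16's limits):

import math
import torch

head_dim = 64
q = torch.full((head_dim,), 40.0, dtype=torch.float16)
k = torch.full((head_dim,), 40.0, dtype=torch.float16)

# divide after the dot product: the raw score is 64 * 40 * 40 = 102400 > 65504 (fp16 max)
late = (q * k).sum() / math.sqrt(head_dim)
# divide q first: the dot product is 64 * 5 * 40 = 12800, well inside fp16's range
early = ((q / math.sqrt(head_dim)) * k).sum()

print(late, early)  # inf vs. 12800.0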

Do you know why this causes gradient overflow? I can't figure out the reason. Thanks!
