Batchnorm2d outputs NaN - Negative running_var

Hi,

I’m trying to understand and fix a problem where my loss becomes NaN. What I know so far:

  • FP16 mixed-precision training (autocast, scaler.scale(loss).backward(), scaler.unscale_(), clip_grad_norm_, scaler.step, scaler.update, zero_grad) diverges to NaN.
  • I traced the issue to a batchnorm layer during an FP32 inference.
  • The sequence is: Conv2d > x > BatchNorm2d > some feature maps are full of NaN.

After checking in depth (manually recomputing the batchnorm), I found that some entries of my batchnorm.running_var are negative, which isn’t supposed to happen (the batchnorm produces NaN when it applies sqrt() to these values).
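To make the manual check concrete, here is a minimal CPU sketch, assuming a BatchNorm2d in eval mode where one running_var entry has been set negative by hand to mimic the corrupted state. Eval-mode batch norm divides by sqrt(running_var + eps), so any channel whose running_var is below -eps turns into NaN while the other channels stay finite.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Eval mode: the layer normalizes with running_mean / running_var, not batch stats.
bn = nn.BatchNorm2d(3, eps=1e-5, momentum=0.1).eval()
x = torch.randn(2, 3, 4, 4)

# Simulate the corrupted statistic (channel 1 only); this value is made up.
bn.running_var[1] = -0.5

with torch.no_grad():
    out = bn(x)

    # Manual recomputation per channel: (x - mean) / sqrt(var + eps) * weight + bias
    shape = (1, -1, 1, 1)
    manual = (x - bn.running_mean.view(shape)) / torch.sqrt(bn.running_var.view(shape) + bn.eps)
    manual = manual * bn.weight.view(shape) + bn.bias.view(shape)

print(torch.isnan(out[:, 1]).all().item())          # channel with negative running_var
print(torch.isfinite(out[:, [0, 2]]).all().item())  # untouched channels
```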

How could the values in running_var be negative?

Does it mean that there’s an issue with the scaler during the training?

Thanks to anyone with more information on this issue

This is indeed strange. Did you see the negative var values after the mixed-precision training when you tried to evaluate the model in FP32?
I assume you didn’t set a negative momentum or anything like that?
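For context, the BatchNorm2d documentation gives the running-stat update as new = (1 - momentum) * old + momentum * observed, with running_var using the unbiased batch variance. A plain-Python sketch of that rule for a single channel (an illustration, not the real kernel) shows why only a negative momentum or an already-corrupted batch variance could push running_var below zero:

```python
def update_running_var(running_var, batch, momentum=0.1):
    """One running-variance update for a single channel:
    new = (1 - momentum) * old + momentum * unbiased_batch_var."""
    n = len(batch)
    mean = sum(batch) / n
    batch_var = sum((v - mean) ** 2 for v in batch) / (n - 1)  # unbiased, always >= 0
    return (1 - momentum) * running_var + momentum * batch_var

# With 0 <= momentum <= 1 the update is a convex combination of two
# non-negative numbers, so it cannot drop below zero.
ok = update_running_var(1.0, [0.0, 2.0], momentum=0.1)

# A negative momentum removes that guarantee and the result goes below zero.
broken = update_running_var(0.0, [0.0, 2.0], momentum=-0.5)
print(ok, broken)
```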

Momentum was 0.1 and eps was 1e-5.
The negative variance values appear during training, so I guess that even with the regular FP16 training procedure there are still ways to get an overflow here.

After changing my model, training is more stable. I apply torch.clamp to the input and the output of the model, and I also apply tanh after each convolution so that large values cannot build up there or in the batchnorm.
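A sketch of those tweaks under assumptions the post does not state: the tanh sits between each convolution and its batch norm, the clamp bounds are a hypothetical [0, 1e4], and the class names and layer sizes are made up for illustration. Because tanh keeps every conv output in [-1, 1], the batch norm only ever sees bounded inputs.

```python
import torch
from torch import nn

class ConvTanhBN(nn.Module):
    """Hypothetical block: Conv2d -> tanh -> BatchNorm2d."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        # tanh bounds the conv output to [-1, 1] before it reaches the batch norm
        return self.bn(torch.tanh(self.conv(x)))

class ClampedModel(nn.Module):
    """Clamps the model input and output to [lo, hi] (bounds are assumptions)."""
    def __init__(self, body, lo=0.0, hi=1e4):
        super().__init__()
        self.body, self.lo, self.hi = body, lo, hi

    def forward(self, x):
        x = torch.clamp(x, self.lo, self.hi)
        return torch.clamp(self.body(x), self.lo, self.hi)

model = ClampedModel(nn.Sequential(ConvTanhBN(1, 8), nn.Conv2d(8, 1, kernel_size=3, padding=1)))
x = torch.rand(4, 1, 32, 30) * 2e4  # some inputs fall outside [0, 1e4]
out = model(x)
```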

Apart from an overflow, I don’t see how I could get a negative variance in the batchnorm.

Batchnorm layers are running in FP32 in an autocast region, so they shouldn’t overflow internally.
Just to make sure I’m understanding your use case: “regular fp16 training” means you are using torch.cuda.amp or are you calling manually half() on any modules?

No, I’m not calling half() anywhere. It means I’m using the procedure from here: Automatic Mixed Precision examples — PyTorch 1.8.0 documentation.

My code:

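Main training loop (model, loss_model, optimizer, dataloader and scaler = torch.cuda.amp.GradScaler() are created elsewhere):

    import torch

    for batch in dataloader:
        batch['x'] = batch['x'].cuda(non_blocking=True)

        # forward pass under autocast
        with torch.cuda.amp.autocast():
            out = model(batch['x'])
            loss = loss_model(out, batch['x'])  # autoencoder loss

        # backward pass with loss scaling
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # unscale first so clipping sees the true gradient norm
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        scaler.step(optimizer)  # skipped internally if the gradients contain inf/NaN
        scaler.update()
        optimizer.zero_grad()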

batch['x'] has shape [64, 1, 256, 150], with values ranging from 0 to 1e4.
I’m aware of the pitfalls of manually calling half() on a model.
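As a rough illustration of why inputs around 1e4 are risky under autocast (where conv2d runs in float16): the largest finite float16 value is 65504. Assuming, hypothetically, a 3x3 convolution over 128 input channels with every input at 1e4 and every weight at 0.2, even a float32 accumulation gives a value that an FP16 output tensor cannot represent.

```python
import torch

# Largest finite float16 value.
fp16_max = torch.finfo(torch.float16).max

# One output element of a 3x3 conv over 128 input channels sums 128 * 3 * 3 = 1152 products.
# Hypothetical worst case: all inputs at 1e4, all weights at 0.2.
x = torch.full((1152,), 1e4)
w = torch.full((1152,), 0.2)
acc = (x * w).sum()              # float32 accumulation: about 2.3e6, still finite
stored = acc.to(torch.float16)   # does not fit in float16, so it becomes inf
print(fp16_max, acc.item(), stored.item())
```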

Which batch norm are you getting (i.e., what is the grad_fn of the batch norm output)? It could be cuDNN or native.

The grad_fn for the output just after the batchnorm is CudnnBatchNormBackward
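To check which implementation ran, inspect the class name of the output's grad_fn. A minimal CPU sketch: on CPU the native kernel runs, so the name is a NativeBatchNormBackward variant; on a CUDA tensor with cuDNN enabled it is CudnnBatchNormBackward, and newer releases append a numeric suffix such as 0.

```python
import torch
from torch import nn

bn = nn.BatchNorm2d(4)        # training mode; weight/bias require grad, so grad_fn is set
x = torch.randn(2, 4, 8, 8)
out = bn(x)

# The backward node's class name reveals which batch norm kernel ran the forward pass.
name = type(out.grad_fn).__name__
print(name)
```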

OK. Then at least it’s not my code doing it. ;) I had tried to see whether I could make the Welford variance update negative, but it doesn’t seem easy: you would need x - avg and x - (avg + (x - avg)/n) to have different signs.
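For reference, a plain-Python sketch of Welford's online variance (the textbook algorithm, not the actual CUDA kernel) shows why the update is hard to make negative: in exact arithmetic x - new_avg = (x - avg)(n - 1)/n, which has the same sign as x - avg, so each increment is non-negative. An inf input produces NaN, not a negative value.

```python
def welford_variance(values):
    """Population variance via Welford's online algorithm."""
    n = 0
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations
    for x in values:
        n += 1
        delta = x - mean       # x - avg (old mean)
        mean += delta / n      # avg + (x - avg) / n
        delta2 = x - mean      # x - new avg
        # In exact arithmetic delta2 = delta * (n - 1) / n, so this product is >= 0.
        m2 += delta * delta2
    return m2 / n if n else float("nan")

print(welford_variance([1.0, 2.0, 3.0, 4.0]))
print(welford_variance([1.0, float("inf"), 2.0]))  # inf - inf inside the update -> NaN
```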

More seriously, you could try to see whether you can avoid the problem by calling batch norm in a with torch.backends.cudnn.flags(enabled=False): block (though I haven’t tested whether it works or helps).
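One way to apply that suggestion to every batch norm layer is a small subclass whose forward runs inside torch.backends.cudnn.flags(enabled=False). The class name NoCudnnBatchNorm2d is hypothetical, and whether this avoids the problem is untested; the sketch only checks that the wrapper computes the same result as a plain BatchNorm2d and leaves the global flag as it was (on CPU, where cuDNN is not involved either way).

```python
import torch
from torch import nn

class NoCudnnBatchNorm2d(nn.BatchNorm2d):
    """Hypothetical BatchNorm2d whose forward pass runs with cuDNN disabled."""
    def forward(self, x):
        # flags() temporarily sets the cuDNN backend flags (enabled=False here)
        # and restores the previous values when the block exits.
        with torch.backends.cudnn.flags(enabled=False):
            return super().forward(x)

before = torch.backends.cudnn.enabled
torch.manual_seed(0)
x = torch.randn(2, 4, 8, 8)
ref, patched = nn.BatchNorm2d(4), NoCudnnBatchNorm2d(4)
out_ref, out_patched = ref(x), patched(x)
```

Swapping this class in for nn.BatchNorm2d in the model definition would route every batch norm call through the native kernel on GPU.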

Best regards

Thomas

Thanks, I’ll try that if I have this error again.

As I said, my training seems more stable since I tweaked my model, so I guess I just had an unstable architecture.
It’s weird that a variance could be negative, so it might be worth fixing, in case someone else has the same problem later or if it points to something wrong in the existing code.

My architecture was a convolutional autoencoder: 2dmap > conv2d(bn2d/act) > conv2d > … > avgpool2d > features > bn1d > linear > bn1d > convT2d > … > 2dmap. Removing the two 1D-BN layers in the middle may have solved my issue, but I’m not sure, and it’s odd because the negative variances were in the first 2D-BN. (I’m leaving this here in case someone else has the same issue later.)

I don’t think the error was caused by the architecture; as @tom mentioned, it might be a faulty kernel in cuDNN, which should then be fixed. If you have an executable code snippet or run into this error again, could you please post it here?

I’m running into this issue as well. I’ve debugged it a bit by dumping the input data & network state_dict as soon as a NaN loss is detected.
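A sketch of that dump step, with a hypothetical helper name and file name: check the loss right after the forward pass and, if it is not finite, save the offending batch together with the model's state_dict so the failing step can be replayed offline. The state is saved after the forward pass, so any BatchNorm running stats in the dump already include this batch's update.

```python
import os
import tempfile

import torch
from torch import nn

def dump_if_nonfinite(loss, inputs, model, path):
    """Save inputs and model state if the loss is inf/NaN (hypothetical helper)."""
    if torch.isfinite(loss).all():
        return False
    torch.save({"inputs": inputs.detach().cpu(), "state_dict": model.state_dict()}, path)
    return True

model = nn.Sequential(nn.Conv2d(1, 4, 3), nn.BatchNorm2d(4))
x = torch.randn(2, 1, 8, 8)
x[0, 0, 0, 0] = float("inf")  # poison one input element
loss = model(x).mean()

path = os.path.join(tempfile.mkdtemp(), "nan_dump.pt")
dumped = dump_if_nonfinite(loss, x, model, path)
restored = torch.load(path)
```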

When I step through the network with that data a specific Conv2d produces an output that has some inf elements. The Conv2d’s weights & bias appear reasonable:

 |   mean   |   min    |   max    |   std    |       shape        || param_name  
 |    0.208 |   -1.560 |    9.680 |    0.997 |    [64, 128, 3, 3] || weight
 |    0.662 |   -2.702 |    8.658 |    1.713 |               [64] || bias
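Stepping through the network by hand can be automated with forward hooks. The sketch below (hypothetical helper, leaf modules only, outputs assumed to be plain tensors) records every module whose output contains inf or NaN, in call order, so the first entry is the earliest offender. To make the demo deterministic, the conv weights are forced large enough to overflow float32, and the model is put in eval mode so the probe does not update the BatchNorm running stats.

```python
import torch
from torch import nn

def find_nonfinite_modules(model, x):
    """Return names of leaf modules whose output contains inf/NaN, in call order."""
    bad, handles = [], []
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            def hook(mod, inputs, output, name=name):
                if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                    bad.append(name)
            handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(x)
    finally:
        for h in handles:  # always detach the probes
            h.remove()
    return bad

model = nn.Sequential(nn.Conv2d(1, 4, 3, padding=1), nn.BatchNorm2d(4), nn.ReLU()).eval()
with torch.no_grad():
    model[0].weight.fill_(1e38)  # sums of several 1e38 terms exceed the float32 max
print(find_nonfinite_modules(model, torch.ones(1, 1, 8, 8)))
```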

My guess is that the inf elements lead to NaNs when incorporated into the BatchNorm2d’s running stats.

I have dumped, and can share, the specific inputs and state dicts for the Conv2d and BatchNorm2d if they would help with debugging. However, the BatchNorm2d has already been polluted and has a NaN entry in both running_mean and running_var. It seems there is first an inf loss and then the NaN, though I haven’t confirmed that hypothesis.

Ideally, the BatchNorm2d layer would not update its running stats with inf inputs during AMP training, so that the scaler could fully roll back the iteration when an inf is detected.

(Maybe it already does that and the issue is more complicated; I don’t know.)

Confirmed that BatchNorm2d’s running stats do indeed get ruined when it receives inf input.
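A minimal CPU reproduction of that behavior, assuming a BatchNorm2d in training mode (the default) and a single inf element in one channel: after one forward pass the running_mean and running_var of that channel are no longer finite, while the other channels are unaffected because the stats are computed per channel.

```python
import torch
from torch import nn

bn = nn.BatchNorm2d(3)        # training mode: running stats are updated on every forward
x = torch.randn(4, 3, 8, 8)
x[0, 1, 0, 0] = float("inf")  # a single inf in channel 1

bn(x)

print(bn.running_mean)  # channel 1 is no longer finite
print(bn.running_var)   # channel 1 is no longer finite; channels 0 and 2 are fine
```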

Filed a bug for this: BatchNorm2d doesn't handle inf input in AMP training · Issue #90342 · pytorch/pytorch · GitHub

Invalid activation values during the forward pass are not expected in mixed-precision training, as explained in the issue you created, so you would need to make sure your training isn’t exploding. The GradScaler is responsible for gradient scaling and cannot rewind the forward pass.

Invalid activation values can happen sometimes. Unfortunately, nothing prevents the running stats from being updated with bad values, and there is no way to cancel the update. A good pattern might be to cache the running stats from the previous iteration and have a method to restore them (e.g., if the loss is large and we’d like to skip the batch).
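A minimal sketch of that pattern, with hypothetical helper names: copy the running buffers of every batch norm layer before the forward pass, and write them back if the loss turns out to be non-finite (a threshold on a "large" loss would work the same way). This is not a PyTorch API, and checking torch.isfinite(loss) from Python forces a host-device synchronization on every step.

```python
import torch
from torch import nn

_BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)

def snapshot_bn_stats(model):
    """Copy every batch norm layer's running buffers (hypothetical helper)."""
    stats = {}
    for name, m in model.named_modules():
        if isinstance(m, _BN_TYPES) and m.track_running_stats:
            stats[name] = (m.running_mean.clone(), m.running_var.clone(),
                           m.num_batches_tracked.clone())
    return stats

def restore_bn_stats(model, stats):
    """Write a snapshot from snapshot_bn_stats back into the model, in place."""
    modules = dict(model.named_modules())
    with torch.no_grad():
        for name, (mean, var, count) in stats.items():
            m = modules[name]
            m.running_mean.copy_(mean)
            m.running_var.copy_(var)
            m.num_batches_tracked.copy_(count)

model = nn.Sequential(nn.Conv2d(1, 4, 3), nn.BatchNorm2d(4))
x = torch.randn(2, 1, 8, 8)
model(x)  # one good step so the stats are non-trivial
saved = snapshot_bn_stats(model)

bad = x.clone()
bad[0, 0, 0, 0] = float("inf")
loss = model(bad).mean()
if not torch.isfinite(loss):
    restore_bn_stats(model, saved)  # undo the polluted running-stat update
    # ...and skip the optimizer step for this batch
```

Restoring only rolls back the running stats; the optimizer step still has to be skipped separately, although GradScaler already skips it when the gradients contain inf/NaN.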