ZeroDivisionError and Loss goes to NaN with Apex Loss Scaling

I am using an open source distributed PyTorch implementation of training AlexNet from scratch on ImageNet (https://github.com/richardkxu/distributed-pytorch).

This implementation works flawlessly as is. As soon as I add an additional loss (loss_contrastive) in the following manner:

loss = criterion(output, target)
loss_contrastive = getContrastiveLoss(target, rep3, rep4, rep5, contrastive_idxs)
loss += 0.1*loss_contrastive

optimizer.zero_grad()

# Mixed-precision training requires that the loss is scaled in order
# to prevent the gradients from underflowing
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

I get a ZeroDivisionError on the last line. I am also getting a gradient overflow message for many consecutive steps (Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 5e-324). Looking at the two losses separately, both start at around ~10, and then loss_contrastive begins rapidly increasing. After many steps with loss_contrastive at around ~10^8 and many gradient overflows (at this point the original loss is ~50), both losses become NaNs.

For context, loss_contrastive is simply a contrastive MSE loss which aims to minimize the distance between certain representations for certain inputs and maximize it for others. Am I treating the addition of a new loss incorrectly? Any ideas what might be causing this?
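For reference, the computation is roughly of this form (a simplified sketch, not my exact implementation; it assumes the representations are flattened to (batch, features) and that the pair indices are LongTensors of shape (num_pairs, 2)):

import torch
import torch.nn.functional as F

def contrastive_mse_sketch(rep, pos_pairs, neg_pairs, margin=1.0):
    # pull representations of "positive" pairs together via an MSE term ...
    pos_loss = F.mse_loss(rep[pos_pairs[:, 0]], rep[pos_pairs[:, 1]])
    # ... and push "negative" pairs at least `margin` apart
    neg_dist = ((rep[neg_pairs[:, 0]] - rep[neg_pairs[:, 1]]) ** 2).mean(dim=1)
    neg_loss = torch.clamp(margin - neg_dist, min=0).mean()
    return pos_loss + neg_loss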

Thanks!

This doesn’t seem right. Are you expecting the loss to have such a high value?
If the loss becomes NaN, could you check if your model is also outputting NaNs?

The ZeroDivisionError is raised because the loss scaler detects NaN gradients (due to an overflow, which can occasionally happen and is expected, or due to the NaN loss, which should never happen) and thus keeps reducing the scaling factor until it is no longer representable.
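To illustrate, the dynamic loss scaling behaves roughly like this (a simplified sketch, not apex's actual code; params stands for the model parameters):

import torch

scale = 2.0 ** 16  # initial loss scale

def scaled_backward(loss, params):
    global scale
    (loss * scale).backward()
    overflow = any(p.grad is not None and not torch.isfinite(p.grad).all()
                   for p in params)
    if overflow:
        # the step is skipped and the scale is halved; a NaN loss makes
        # every step "overflow", so the scale keeps shrinking (5e-324 is
        # the smallest denormal) until it finally underflows to 0.0
        scale /= 2.0
        for p in params:
            if p.grad is not None:
                p.grad.zero_()
        return False
    # unscaling the gradients divides by the scale, which is where the
    # ZeroDivisionError comes from once the scale has reached 0.0
    inv_scale = 1.0 / scale
    for p in params:
        if p.grad is not None:
            p.grad.mul_(inv_scale)
    return True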

That being said, we recommend using the native amp implementation (torch.cuda.amp) instead of apex.
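For reference, the native usage looks roughly like this (a sketch; model, criterion, and optimizer follow your snippet, and train_loader is a placeholder for your data loader):

import torch

scaler = torch.cuda.amp.GradScaler()

for input, target in train_loader:
    optimizer.zero_grad()
    # run the forward pass in mixed precision
    with torch.cuda.amp.autocast():
        output = model(input)
        loss = criterion(output, target)
    # GradScaler scales the loss, unscales the gradients before the step,
    # and skips the step if inf/NaN gradients are found
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()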

Hi @ptrblck,

Thanks so much for your response. I decreased the learning rate by an order of magnitude, as per a suggestion on a related GitHub issue, and this seemed to resolve the problem! And thank you for your suggestion; I will look into the native amp implementation.

I do have a quick follow-up though: Given that the original loss trained well with a learning rate of 0.01, and the new loss trains well with a different learning rate of 0.001, I thought it would make sense to use two separate optimizers in the following manner:

optimizer = torch.optim.SGD(model.parameters(), 
                            0.01, 
                            momentum=args.momentum, 
                            weight_decay=args.weight_decay)
parameters_contrastive = (list(model.conv2.parameters()) +
                          list(model.maxpool2.parameters()) +
                          list(model.conv3.parameters()))
optimizer_contrastive = torch.optim.SGD(parameters_contrastive, 
                                        0.001,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)

model, [optimizer, optimizer_contrastive] = amp.initialize(model, 
                                                           [optimizer, optimizer_contrastive],
                                                           opt_level=args.opt_level,
                                                           keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                                           loss_scale=args.loss_scale,
                                                           num_losses=2)

and in the train loop:

output, rep3, rep4, rep5 = model(input)

loss = criterion(output, target)

# compute contrastive loss
loss_contrastive = getContrastiveLoss(target, rep3, rep4, rep5, contrastive_idxs)
loss_contrastive *= 0.1

# compute gradient and do SGD step
optimizer.zero_grad()
optimizer_contrastive.zero_grad()

# Mixed-precision training requires that the loss is scaled in order
# to prevent the gradients from underflowing
with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss:
    scaled_loss.backward(retain_graph=True)
with amp.scale_loss(loss_contrastive, optimizer_contrastive, loss_id=1) as scaled_loss:
    scaled_loss.backward()

optimizer.step()
optimizer_contrastive.step()

While both losses (loss and loss_contrastive) start off at the same order of magnitude (after loss_contrastive is scaled by 0.1), loss_contrastive decreases very smoothly throughout training while loss remains pretty much constant. To reiterate from my original post: when this exact model is trained without the additional loss_contrastive, the loss decreases very smoothly and the AlexNet achieves a final top-5 accuracy of 80% on ImageNet.

I am unsure how the addition of my loss_contrastive is discouraging the network from optimizing the original, main loss. Do you have any idea why this might be the case, or what I can do to combat this?

Thanks once again for your help; you have helped me numerous times in the past as well, and it always stuns me to see how people of your calibre are so generous and accessible even to beginners. I truly appreciate all your effort.

This is pure speculation, but I think that in your use case, where both losses are used, the gradients of loss might not yield a proper signal and your model in fact mainly optimizes for loss_contrastive.
This might happen if the gradient magnitudes differ, so that the gradients of loss could be seen as noise compared to the gradients of loss_contrastive.
If that's the case, the optimization might lower one loss while leaving the other flat or even increasing it.
You could try different weightings on the losses to bring their values into an equal range and rerun the code.
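One way to check this (a sketch; it assumes loss and loss_contrastive are still attached to the graph) would be to compare the gradient norms coming from each loss and then adjust the 0.1 weighting so both contribute gradients of a similar magnitude:

import torch

def grad_norm(loss, params):
    # gradients of a single loss w.r.t. the given parameters
    grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
    return torch.norm(torch.stack([g.norm() for g in grads if g is not None]))

shared_params = [p for p in model.parameters() if p.requires_grad]
print('grad norm of loss            :', grad_norm(loss, shared_params).item())
print('grad norm of loss_contrastive:', grad_norm(loss_contrastive, shared_params).item())
# if one norm dwarfs the other, rescale the 0.1 weighting accordingly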

Thanks for the kind words and I’m glad I could help. :slight_smile: