Exploding Loss with GradScaler and frozen Batchnorm Layers

Hi,
Today I noticed that when I freeze my BatchNorm2d layers and use torch.cuda.amp.GradScaler, my losses explode after just 3 or 4 batches. The same code and parameters give very good results when the BN layers are not frozen. I have to scale down the learning rate to get a working training process again.

This is obviously not a bug report; I just cannot figure out the reason behind it. I am using a DeepLabV3 with 107 BatchNorm layers. To freeze them, I simply set them to eval mode. Maybe someone can help me understand why this is happening.

Thanks for reading!

Could you post a minimal, executable code snippet which shows the loss explosion in amp, please?

Sure!

    import torch
    import torch.nn.functional as functional
    import torch.optim as optim

    # `model`, `train_loader`, `device` and `batch_to_device` are defined elsewhere (not shown)
    learning_rate = 0.01
    weight_decay = 0.0005
    momentum = 0.9
    optimizer = optim.SGD(
        model.parameters(), learning_rate, momentum=momentum,
        weight_decay=weight_decay, nesterov=True)
    grad_scaler = torch.cuda.amp.GradScaler()
    freeze_bn = True

    # nn.Module.train() returns the module itself; it is not a context manager
    model.train()
    if freeze_bn:
        model.freeze_bn()  # custom method: puts the BN layers back into eval mode (see below)
    optimizer.zero_grad(set_to_none=False)

    for batch_idx, batch in enumerate(train_loader):
        batch = batch_to_device(batch, device)

        outputs = model(batch)

        targets = batch['targets']
        preds = outputs['segmentation_logits']
        loss_seg = functional.cross_entropy(preds, targets, ignore_index=255)

        if grad_scaler is not None:
            # https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html
            grad_scaler.scale(loss_seg).backward()  # backward on the scaled loss
            grad_scaler.step(optimizer)             # unscales grads, skips the step on inf/NaN
            grad_scaler.update()                    # adjusts the scale factor
            optimizer.zero_grad(set_to_none=False)

Freezing bn:

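    from torch.nn import BatchNorm2d
    from sync_batchnorm import SynchronizedBatchNorm2d  # third-party sync-BN package (assumed import path)

    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, (BatchNorm2d, SynchronizedBatchNorm2d)):
                m.eval()  # use the running stats and stop updating them
                # eval() does not stop gradients to the affine parameters,
                # so these two lines are needed to freeze weight and bias as well
                m.weight.requires_grad = False
                m.bias.requires_grad = False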
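A quick way to check what `eval()` does and does not freeze is a single standalone `nn.BatchNorm2d`. The sketch below uses a toy layer and random input (not the DeepLabV3 model) and shows that `eval()` keeps the running statistics fixed, while `weight` and `bias` still receive gradients until `requires_grad` is set to `False`. With `eval()` alone, SGD (including its weight decay) would therefore keep updating the affine parameters.

```python
import torch
import torch.nn as nn


def bn_param_grads(freeze_affine: bool):
    """Return (weight got a gradient?, running stats unchanged?) for one backward pass."""
    torch.manual_seed(0)
    bn = nn.BatchNorm2d(4)
    bn.eval()  # normalize with running stats and stop updating them
    if freeze_affine:
        bn.weight.requires_grad_(False)
        bn.bias.requires_grad_(False)
    # requires_grad on the input so backward() works even when BN params are frozen
    x = torch.randn(8, 4, 5, 5, requires_grad=True)
    mean_before = bn.running_mean.clone()
    bn(x).sum().backward()
    stats_fixed = torch.equal(mean_before, bn.running_mean)
    return bn.weight.grad is not None, stats_fixed


print(bn_param_grads(freeze_affine=False))  # eval() only
print(bn_param_grads(freeze_affine=True))   # eval() plus requires_grad=False
```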

Your code is unfortunately not executable, as most of the object definitions are missing.

Sorry about that. I added most of the definitions. Are the model and data loader definitions needed too? Is anything else still missing?

Edit: My guess is that the explosion is caused by a learning rate that is too high. But why is it too high with frozen BN layers, yet perfectly fine when they are not frozen?
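One difference that may be relevant here, offered as a hypothesis rather than a diagnosis of this particular model: in train mode, BN normalizes with the batch statistics, so its output does not change when the weights of the preceding conv are rescaled. With frozen BN (eval mode, fixed running statistics) that invariance is gone and the output scales with the weights. The toy sketch below uses an assumed conv/BN pair and random input and compares both cases for a 10x weight scaling.

```python
import torch
import torch.nn as nn


def bn_output_ratio(bn_train_mode: bool, factor: float = 10.0) -> float:
    """Ratio ||BN(conv_scaled(x))|| / ||BN(conv(x))|| after scaling the conv weights."""
    torch.manual_seed(0)
    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
    bn = nn.BatchNorm2d(8)   # fresh layer: running_mean = 0, running_var = 1
    bn.train(bn_train_mode)  # True: batch statistics, False: frozen running statistics
    x = torch.randn(4, 3, 16, 16)
    with torch.no_grad():
        y_ref = bn(conv(x))
        conv.weight.mul_(factor)  # rescale the weights, as a large update might
        y_scaled = bn(conv(x))
    return (y_scaled.norm() / y_ref.norm()).item()


print(f"train-mode BN: {bn_output_ratio(True):.4f}")   # close to 1: the scale is normalized away
print(f"frozen BN:     {bn_output_ratio(False):.4f}")  # close to 10: the scale passes through
```

If this is what happens in the real network, the same learning rate would cause larger changes in the activations once BN is frozen, which would be consistent with needing a lower learning rate.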

Yes, the model is also needed as I won’t be able to execute the code otherwise.
Your current code snippet doesn't use autocast at all, so I'm also unsure why you are scaling the gradients in the first place.
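For reference, the AMP recipe linked in the snippet runs the forward pass and the loss computation under `torch.autocast` and applies `GradScaler` to the resulting loss. A minimal sketch of that pattern, combined with BN freezing, might look like the following. The tiny model, the random data, and the class count are assumptions standing in for DeepLabV3 and the real loaders, and AMP is only enabled when CUDA is available.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # fp16 autocast + GradScaler only on CUDA in this sketch

torch.manual_seed(0)
# Hypothetical stand-in for the segmentation model: conv -> BN -> ReLU -> 1x1 head.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 4, kernel_size=1),  # 4 hypothetical classes
).to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9,
                            weight_decay=0.0005, nesterov=True)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # no-op wrapper when disabled

model.train()
for m in model.modules():  # freeze BN *after* model.train(), or it gets undone
    if isinstance(m, nn.BatchNorm2d):
        m.eval()
        m.weight.requires_grad_(False)
        m.bias.requires_grad_(False)

losses = []
for step in range(3):
    inputs = torch.randn(2, 3, 16, 16, device=device)
    targets = torch.randint(0, 4, (2, 16, 16), device=device)
    optimizer.zero_grad(set_to_none=True)
    # Forward pass and loss under autocast, as in the AMP recipe.
    with torch.autocast(device_type=device, enabled=use_amp):
        logits = model(inputs)
        loss = F.cross_entropy(logits, targets, ignore_index=255)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales grads; skips the step on inf/NaN
    scaler.update()                # adjusts the scale factor for the next step
    losses.append(loss.item())

print(losses)
```

Without autocast the forward pass stays in float32, so loss scaling is not needed to keep the gradients representable; `GradScaler` then only scales the loss, unscales the gradients again, and skips steps whose gradients contain inf/NaN.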