Validation loss not changing when adding dropout

Hello!!

As the title says, when I add dropout to the model, the validation loss seems to stop changing (worth noting that without dropout the model overfits just fine and the validation loss does change). I’m using PyTorch Lightning, which is supposed to put the model in eval mode, and therefore turn dropout off, when running validation_step.
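
For reference, this is the kind of check I could drop into validation_step to confirm the dropout layers really are in eval mode there (just a sketch, not part of the code below):

    def validation_step(self, batch, batch_idx):
        # Lightning should have called model.eval() before this hook runs,
        # so every Dropout2d module should report training == False
        for name, module in self.named_modules():
            if isinstance(module, torch.nn.Dropout2d):
                assert not module.training, f"{name} still in train mode during validation"
        ...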

I’m not sure how to debug this or what to try next, so I’ll post the code I’m using and the results in the hope that someone can help me find the problem.

Model definition

    def __init__(self, input_shape, dropout_p=0, *args, **kwargs):
        
        ...

        conv3 = torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3))

        # no dropout requested -> fall back to batch norm, otherwise use Dropout2d
        if not dropout_p:
            regularization = ('batch_norm', torch.nn.BatchNorm2d(128))
        else:
            regularization = ('dropout', torch.nn.Dropout2d(dropout_p))

        layers = OrderedDict([
            ('conv3', conv3), ('relu', torch.nn.ReLU()),
            ('pool', torch.nn.MaxPool2d((2, 2))), regularization,
            ('flatten', torch.nn.Flatten(start_dim=1)),
            ('output', CatDogOutput(460800))
        ])
        self.model = torch.nn.Sequential(layers)
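
As a standalone sanity check on Dropout2d itself (independent of the model above), it only zeroes channels in train mode and acts as the identity in eval mode:

    import torch

    drop = torch.nn.Dropout2d(p=0.5)
    x = torch.randn(1, 4, 2, 2)
    drop.train()
    print(drop(x))                  # some channels zeroed, the rest scaled by 1 / (1 - p)
    drop.eval()
    print(torch.equal(drop(x), x))  # True: dropout is a no-op in eval mode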

Validation step

    def validation_step(self, batch, batch_idx):
        target, pred_target, bbox, pred_bbox, classification_loss, bbox_loss, loss = self._shared_step(batch)
        imgs = batch[0]
        # pick a random image from the batch to return alongside the predictions
        sel_idx = np.random.randint(len(imgs))
        return {"img": imgs[sel_idx], "sel_idx": sel_idx, "target": target,
                "pred_target": pred_target, "pred_bbox": pred_bbox,
                "classification_loss": classification_loss, "bbox_loss": bbox_loss, "loss": loss}

    def _shared_step(self, batch):
        img, target, bbox = batch
        model_input = self.preprocess_img(img)
        pred_target, pred_bbox = self.forward_pass(model_input)
        # add a channel dimension to target so its shape matches pred_target
        target = target.unsqueeze(1)

        classification_loss = F.binary_cross_entropy(pred_target, target)
        bbox_loss = F.mse_loss(pred_bbox, bbox)
        loss = classification_loss + self.hparams.bbox_alpha * bbox_loss

        return target.int(), pred_target, bbox, pred_bbox, classification_loss, bbox_loss, loss

    def forward_pass(self, img):
        output = self.model(img)
        # import ipdb; ipdb.set_trace()
        return output
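
In case it helps, a minimal training_step paired with _shared_step could look like this (a sketch, not my exact code):

    def training_step(self, batch, batch_idx):
        *_, classification_loss, bbox_loss, loss = self._shared_step(batch)
        # log each component so the train and validation curves can be compared
        self.log("train_classification_loss", classification_loss)
        self.log("train_bbox_loss", bbox_loss)
        self.log("train_loss", loss)
        return loss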

Loss behaviour