CNN sometimes does not learn at all

Hello all,

I have built the following CNN and was able to train it successfully (Validation Loss and Acc: 0.0081/0.9991).
However, in roughly half of the runs (and once it happens, it tends to happen several times in a row) the CNN does not learn at all: it stays at a loss of around 3.8 and an accuracy of about 0.02, which is roughly chance level for 43 classes (ln 43 ≈ 3.76, 1/43 ≈ 0.023). Since it does learn sometimes, I cannot explain what causes this. I have already changed the learning rate, tried other optimizers, and varied the batch size, but the problem can occur with all of these variations. I have also tested different training/validation splits, without success.

As mentioned before, the problem seems to occur in streaks. If it does not learn with dataset a, but then learns with b, it usually learns with a afterwards as well. Each time I call litmodel = LitModel(hparams, mean, std, train_val_folder, test_folder), so as far as I know all weights should be re-initialized in the __init__ of the network. Therefore it should not matter whether the previous run learned or not.
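
A quick way to confirm this is to compare the parameters of two freshly constructed instances (a minimal sketch; the variable names model_a/model_b are only for illustration):

model_a = LitModel(hparams, mean, std, train_val_folder, test_folder)
model_b = LitModel(hparams, mean, std, train_val_folder, test_folder)

# With a random (unseeded) init this is expected to print False,
# i.e. the two instances do not share the same weights.
print(all(torch.equal(p, q) for p, q in zip(model_a.parameters(), model_b.parameters())))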

Does anyone else have ideas where the problem might be?

from pathlib import Path

import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import pytorch_lightning as pl
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader

# Stn is a custom module defined elsewhere (not shown here).


class LitModel(pl.LightningModule):
    def __init__(self, hparams, mean, std, train_dataset, test_dataset):
        super(LitModel, self).__init__()
        self.hparams = hparams
        self.mean = mean
        self.std = std
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset

        self.activation_function_features = (
            nn.ELU()
            if self.hparams.activation_function_features == "ELU"
            else nn.ReLU()
        )
        self.activation_function_classifier = (
            nn.ELU()
            if self.hparams.activation_function_classifier == "ELU"
            else nn.ReLU()
        )

        self.features = nn.Sequential(
            # Hidden 1
            Stn(self.hparams.stn_parameter[0]),
            nn.Conv2d(3, 200, (7, 7), stride=1, padding=2),
            self.activation_function_features,
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(self.hparams.dropout_rate),
            nn.BatchNorm2d(200),
            # Hidden 2
            Stn(self.hparams.stn_parameter[1]),
            nn.Conv2d(200, 250, (4, 4), stride=1, padding=2),
            self.activation_function_features,
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(self.hparams.dropout_rate),
            nn.BatchNorm2d(250),
            # Hidden 3
            Stn(self.hparams.stn_parameter[2]),
            nn.Conv2d(250, 350, (4, 4), stride=1, padding=2),
            self.activation_function_features,
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(self.hparams.dropout_rate),
            nn.BatchNorm2d(350),
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(12600, 400),
            self.activation_function_classifier,
            nn.Linear(400, 43),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)

        return x

    def configure_optimizers(self):
        if self.hparams.optimizer == "SGD":
            optimizer = optim.SGD(
                self.parameters(),
                lr=self.hparams.learning_rate,
                momentum=self.hparams.momentum,
                nesterov=True
            )
        else:
            # Fail loudly instead of returning an undefined optimizer
            raise ValueError(f"Unsupported optimizer: {self.hparams.optimizer}")
        print("\n")
        print(optimizer)
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {"train_loss": loss}
        return {"loss": loss, "log": tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        _, y_hat = torch.max(y_hat, dim=1)
        val_acc = accuracy_score(y_hat.cpu(), y.cpu())
        val_acc = torch.tensor(val_acc)

        return {"val_loss": loss, "val_acc": val_acc}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_val_acc = torch.stack([x["val_acc"] for x in outputs]).mean()
        print("\n")
        print("Val loss: ", avg_loss)
        print("val acc: ", avg_val_acc)
        print("\n")
        tensorboard_logs = {"val_loss": avg_loss, "avg_val_acc": avg_val_acc}
        return {"val_loss": avg_loss, "progress_bar": tensorboard_logs}

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        _, y_hat = torch.max(y_hat, dim=1)
        test_acc = accuracy_score(y_hat.cpu(), y.cpu())

        return {"test_acc": torch.tensor(test_acc)}

    def test_epoch_end(self, outputs):
        avg_test_acc = torch.stack([x["test_acc"] for x in outputs]).mean()

        tensorboard_logs = {"avg_test_acc": avg_test_acc}
        return {
            "avg_test_acc": avg_test_acc,
            "log": tensorboard_logs,
            "progress_bar": tensorboard_logs,
        }

    def train_dataloader(self):
        train_set_normal = torchvision.datasets.ImageFolder(
            root=str(Path(self.train_dataset).joinpath("train")),
            transform=transforms.Compose(
                [
                    transforms.RandomApply(
                        [
                            transforms.RandomAffine(
                                0, translate=(0.2, 0.2), resample=PIL.Image.BICUBIC
                            ),
                            transforms.RandomAffine(
                                0, shear=20, resample=PIL.Image.BICUBIC
                            ),
                            transforms.RandomAffine(
                                0, scale=(0.8, 1.2), resample=PIL.Image.BICUBIC
                            ),
                        ]
                    ),
                    transforms.ToTensor(),
                    transforms.Normalize(self.mean, self.std),
                ]
            ),
        )

        loader = DataLoader(
            train_set_normal, batch_size=50, num_workers=8, shuffle=True
        )

        return loader

    def val_dataloader(self):
        val_set_normal = torchvision.datasets.ImageFolder(
            root=str(Path(self.train_dataset).joinpath("val")),
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize(self.mean, self.std),]
            ),
        )

        val_loader = DataLoader(val_set_normal, batch_size=50, num_workers=8)

        return val_loader

    def test_dataloader(self):
        test_set_normal = torchvision.datasets.ImageFolder(
            root=str(self.test_dataset),
            transform=transforms.Compose(
                [
                    transforms.Resize((48, 48)),
                    transforms.ToTensor(),
                    transforms.Normalize(self.mean, self.std),
                ]
            ),
        )

        test_loader = DataLoader(test_set_normal, batch_size=50, num_workers=8)

        return test_loader

Hyperparameter:

hparams = {
    "dropout_rate": 0.45,
    "learning_rate": 0.001,
    "momentum": 0.9,
    "optimizer": "SGD",
    "activation_function_features": "RELU",
    "activation_function_classifier": "RELU",
    "stn_parameter": stn_params,
}

Your training seems to be unstable and thus sensitive to the random seed. Using other parameter initializations might sometimes help to stabilize the training overall.
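
One way to check this is to pin the seed right before constructing the model (a minimal sketch, assuming a Lightning version that provides seed_everything; otherwise the random/NumPy/PyTorch seeds can be set manually). If a given seed then consistently learns or consistently gets stuck, the instability is tied to the weight initialization:

import pytorch_lightning as pl

pl.seed_everything(42)  # fixes the Python, NumPy, and PyTorch RNGs
litmodel = LitModel(hparams, mean, std, train_val_folder, test_folder)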

Thank you for your response!
I tried applying the initialization methods from ashunigion’s answer in this Stack Overflow post. Unfortunately, I couldn’t apply all of them to the Conv2d layers because those don’t have an in_features attribute. The ones that did work couldn’t fix the instability either, though.

I also tried playing around with the momentum. At the moment it seems to me that a lower momentum is more stable. Are there any known ways to make SGD with Nesterov momentum more stable without lowering the momentum itself?

For conv layers you could use the _calculate_fan_in_and_fan_out method, which is used in the reset_parameters() method of _ConvNd.
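
For example, something along these lines (only a sketch of a 1/sqrt(fan_in) uniform init, not necessarily the recommended scheme) can be applied to both Conv2d and Linear layers via apply():

import math
import torch.nn as nn
from torch.nn.init import _calculate_fan_in_and_fan_out

def uniform_fan_in_init(m):
    # Mimics a 1/sqrt(n) uniform init, using fan_in instead of in_features
    # so that it also covers Conv2d layers.
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        fan_in, _ = _calculate_fan_in_and_fan_out(m.weight)
        bound = 1.0 / math.sqrt(fan_in)
        nn.init.uniform_(m.weight, -bound, bound)
        if m.bias is not None:
            nn.init.uniform_(m.bias, -bound, bound)

# applied after the model has been constructed:
litmodel.apply(uniform_fan_in_init)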

I don’t know if there is a recommended way to stabilize the training besides trying out different init methods and changing other hyperparameters.

I played around with different settings today. In the end, the CNN learns stably with a momentum of 0.7 and reaches the “target” accuracy in 15 instead of 20 epochs. That is perfectly acceptable for me, so I’ll leave it at that for now. Thanks again for your help!