PyTorch Lightning + 3D ResNet fails to learn on small image size

I’m working on a binary classification problem with a custom dataset. I’ve trained a 3D ResNet for 78 epochs on 1x8x8x8 images (C, T, H, W) inside a PyTorch Lightning wrapper, but the model fails to learn. The total dataset size is 480 samples. I could optionally double the dataset by duplicating each 3D volume with extra noise added (sketched below), but I’m not sure whether that would help. The dataset is synthetic, so the validation set is drawn from the same distribution as the training set.
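
For reference, the noisier duplicates would be generated roughly along these lines (a minimal sketch; noise_scale is a placeholder and my actual generator differs):

import torch

def add_noise(volume: torch.Tensor, noise_scale: float = 0.1) -> torch.Tensor:
    # volume has shape (C, T, H, W) = (1, 8, 8, 8); add i.i.d. Gaussian noise
    return volume + noise_scale * torch.randn_like(volume)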

I think my next step is to make the model larger, but I’d appreciate advice on 1) whether there are any obvious bugs in my current code that I’ve overlooked (I’m new to PyTorch and ResNets, and may have made a dumb error), and 2) the best way to make my current model larger. I’ve found the relatively small image size challenging, as most tutorials demonstrate 3D ResNets on ~128x128x128 volumes or similar.

Here is a summary of the model I’m currently using:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv3d-1           [-1, 4, 8, 8, 8]             504
       BatchNorm3d-2           [-1, 4, 8, 8, 8]               8
              ReLU-3           [-1, 4, 8, 8, 8]               0
            Conv3d-4           [-1, 4, 8, 8, 8]             436
       BatchNorm3d-5           [-1, 4, 8, 8, 8]               8
              ReLU-6           [-1, 4, 8, 8, 8]               0
            Conv3d-7           [-1, 4, 8, 8, 8]             436
       BatchNorm3d-8           [-1, 4, 8, 8, 8]               8
              ReLU-9           [-1, 4, 8, 8, 8]               0
    ResidualBlock-10           [-1, 4, 8, 8, 8]               0
           Conv3d-11           [-1, 8, 4, 4, 4]             872
      BatchNorm3d-12           [-1, 8, 4, 4, 4]              16
             ReLU-13           [-1, 8, 4, 4, 4]               0
           Conv3d-14           [-1, 8, 4, 4, 4]           1,736
      BatchNorm3d-15           [-1, 8, 4, 4, 4]              16
           Conv3d-16           [-1, 8, 4, 4, 4]              40
      BatchNorm3d-17           [-1, 8, 4, 4, 4]              16
             ReLU-18           [-1, 8, 4, 4, 4]               0
    ResidualBlock-19           [-1, 8, 4, 4, 4]               0
           Conv3d-20           [-1, 8, 4, 4, 4]           1,736
      BatchNorm3d-21           [-1, 8, 4, 4, 4]              16
             ReLU-22           [-1, 8, 4, 4, 4]               0
           Conv3d-23           [-1, 8, 4, 4, 4]           1,736
      BatchNorm3d-24           [-1, 8, 4, 4, 4]              16
             ReLU-25           [-1, 8, 4, 4, 4]               0
    ResidualBlock-26           [-1, 8, 4, 4, 4]               0
           Conv3d-27          [-1, 16, 2, 2, 2]           3,472
      BatchNorm3d-28          [-1, 16, 2, 2, 2]              32
             ReLU-29          [-1, 16, 2, 2, 2]               0
           Conv3d-30          [-1, 16, 2, 2, 2]           6,928
      BatchNorm3d-31          [-1, 16, 2, 2, 2]              32
           Conv3d-32          [-1, 16, 2, 2, 2]             144
      BatchNorm3d-33          [-1, 16, 2, 2, 2]              32
             ReLU-34          [-1, 16, 2, 2, 2]               0
    ResidualBlock-35          [-1, 16, 2, 2, 2]               0
AdaptiveAvgPool3d-36          [-1, 16, 1, 1, 1]               0
           Linear-37                    [-1, 1]              17
================================================================
Total params: 18,257
Trainable params: 18,257
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.23
Params size (MB): 0.07
Estimated Total Size (MB): 0.30
----------------------------------------------------------------
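
(The summary above comes from torchsummary. Roughly, it was generated with the call below, where myresnet is the model instance shown further down; exact invocation from memory.)

from torchsummary import summary

# input_size excludes the batch dimension: (C, T, H, W)
summary(myresnet, input_size=(1, 8, 8, 8), device="cpu")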

The model’s forward method returns the output passed through a sigmoid (I’ve posted the code for the whole model below in case it’s helpful):

import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Sequential(
                        nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
                        nn.BatchNorm3d(out_channels),
                        nn.ReLU())
        self.conv2 = nn.Sequential(
                        nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                        nn.BatchNorm3d(out_channels))
        self.downsample = downsample  # 1x1 conv to match shapes on the identity path
        self.relu = nn.ReLU()
        self.out_channels = out_channels

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual  # skip connection
        out = self.relu(out)
        return out

class ResNet_3D_MFMs(nn.Module):

    def __init__(self, block, layers, num_classes=1, n_input_channels=1):
        super(ResNet_3D_MFMs, self).__init__()
        self.inplanes = 4
        self.conv1 = nn.Sequential(
                        nn.Conv3d(n_input_channels, 4, kernel_size=5, stride=1, padding=2), # first conv output channels?
                        nn.BatchNorm3d(4),
                        nn.ReLU())
        self.layer0 = self._make_layer(block, 4, layers[0], stride=1)
        self.layer1 = self._make_layer(block, 8, layers[1], stride=2)
        self.layer2 = self._make_layer(block, 16, layers[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(16, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        # 1x1 conv on the identity path when the spatial size or channel count changes
        downsample = None
        if stride != 1 or self.inplanes != planes:
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, planes, kernel_size=1, stride=stride),
                nn.BatchNorm3d(planes),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = torch.sigmoid(self.fc(x))

        return x
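
For what it’s worth, a quick shape/range smoke test passes; layers=[1, 2, 1] matches the block counts in the summary above:

import torch

myresnet = ResNet_3D_MFMs(ResidualBlock, layers=[1, 2, 1])
x = torch.randn(2, 1, 8, 8, 8)  # (N, C, T, H, W)
out = myresnet(x)
print(out.shape)  # torch.Size([2, 1])
print(out.min().item(), out.max().item())  # all values in (0, 1) after the sigmoid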

And finally, my Lightning module uses nn.BCELoss() with an Adam optimizer and an LR scheduler:

import pytorch_lightning
import torch
import torch.nn as nn
import torchmetrics

class VideoClassificationLightningModule(pytorch_lightning.LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.model = myresnet.double()  # cast to float64 to match the dataset dtype
        self.accuracy = torchmetrics.classification.BinaryAccuracy(threshold=0.5)
        self.loss = nn.BCELoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        y_hat = self.model(batch["video"])
        targets = batch["label"].unsqueeze(1).double()

        # loss.backward() is called behind the scenes by PyTorch Lightning
        loss = self.loss(y_hat, targets)
        acc = self.accuracy(y_hat, targets)
        
        # Log metrics to TensorBoard
        self.log_dict({
            "train_acc": acc, "train_loss": loss,
        }, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        y_hat = self.model(batch["video"])
        targets = batch["label"].unsqueeze(1).double()
        
        loss = self.loss(y_hat, targets)
        acc = self.accuracy(y_hat, targets)
        
        # Log metrics to TensorBoard
        self.log_dict({
            "val_acc": acc, "val_loss": loss, 
        }, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.1, patience=10, threshold=0.0001)
        # ReduceLROnPlateau needs a monitored metric, nested per the Lightning docs
        return {
            'optimizer': optimizer,
            'lr_scheduler': {'scheduler': scheduler, 'monitor': 'train_loss'},
        }
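
The training run itself is the stock Lightning loop, roughly as below (dataloader construction omitted; train_loader and val_loader wrap my custom dataset and yield dicts with "video" and "label" keys):

import pytorch_lightning

module = VideoClassificationLightningModule()
trainer = pytorch_lightning.Trainer(max_epochs=78)
trainer.fit(module, train_dataloaders=train_loader, val_dataloaders=val_loader)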

After training for a while I can see that the accuracy and loss have plateaued without improvement:

[TensorBoard screenshot: train/val loss and accuracy curves, flat across epochs]
So my questions are:

  1. Have I made a mistake in the model or training process? I realize it’s much smaller than a typical ResNet, but my images are quite small and can only be downsampled so many times.
  2. What’s the best way to make the model bigger? Should I widen the convolutions rather than add more layers?

Does your model train properly if you don’t wrap it in the Lightning API?

Good idea - I’ve just run a few epochs without the Lightning wrapper and it seems to behave the same (loss stuck at ~0.69, no errors reported).
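
For reference, the plain-PyTorch check looked roughly like this (same model, loss, and optimizer settings; train_loader is the same placeholder as above):

import torch
import torch.nn as nn

model = myresnet.double()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

for epoch in range(5):
    for batch in train_loader:
        optimizer.zero_grad()
        y_hat = model(batch["video"])
        targets = batch["label"].unsqueeze(1).double()
        loss = criterion(y_hat, targets)
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch}: loss {loss.item():.4f}")  # stays around 0.69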