ResNet18 step-by-step recreation help

Hi there.

I am trying to recreate ResNet18 in a linear, step-by-step fashion to help my learning of convolutional networks.

I am able to get similar shapes to ResNet18, and the total parameter count is close to ResNet18's (summary attached below). This is a model for 3 classes.

However, training breaks down as soon as I add the first block of the first residual layer.

A single [Conv -> BatchNorm -> ReLU -> Conv -> BatchNorm -> ReLU] block cripples the model's learning capability, even without the identity connection. Can anyone enlighten me on why that is the case?

The model runs and trains without errors, but the validation accuracy stays flat and does not change at all.

The training and validation code works with torchvision's resnet18 with weights=None, so it is likely not an issue with the training code.
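For reference, the working baseline was set up roughly like this (the 3-class head swap is an assumption to match my model):

import torch.nn as nn
from torchvision import models

baseline = models.resnet18(weights=None)  # untrained reference model
baseline.fc = nn.Linear(512, 3)           # assumed: 3-class head to match my task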

import torch
import torch.nn as nn

class CustomRes3(nn.Module):
    def __init__(self):
        super(CustomRes3,self).__init__()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7,stride=2,padding=3,bias=False)
        # layer 1: two basic blocks, 64 channels
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3,stride=1,padding=1,bias=False)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3,stride=1,padding=1,bias=False)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3,stride=1,padding=1,bias=False)
        self.conv5 = nn.Conv2d(64, 64, kernel_size=3,stride=1,padding=1,bias=False)
        # layer 2: 128 channels (uses bn2)
        self.conv6_downsample = nn.Conv2d(64, 128, kernel_size=1,stride=2, bias=False)
        self.conv6 = nn.Conv2d(64, 128, kernel_size=3,stride=2,padding=1,bias=False) # strided 3x3 conv halves the spatial size; the 1x1 conv above handles the shortcut
        self.conv7 = nn.Conv2d(128, 128, kernel_size=3,stride=1,padding=1,bias=False)
        self.conv8 = nn.Conv2d(128, 128, kernel_size=3,stride=1,padding=1,bias=False)
        self.conv9 = nn.Conv2d(128, 128, kernel_size=3,stride=1,padding=1,bias=False)
        # layer 3: 256 channels (uses bn3)
        self.conv10_downsample = nn.Conv2d(128, 256, kernel_size=1,stride=2, bias=False)
        self.conv10 = nn.Conv2d(128, 256, kernel_size=3,stride=2,padding=1,bias=False)
        self.conv11 = nn.Conv2d(256, 256, kernel_size=3,stride=1,padding=1,bias=False)
        self.conv12 = nn.Conv2d(256, 256, kernel_size=3,stride=1,padding=1,bias=False)
        self.conv13 = nn.Conv2d(256, 256, kernel_size=3,stride=1,padding=1,bias=False)
        # layer 4: 512 channels (uses bn4)
        self.conv14_downsample = nn.Conv2d(256, 512, kernel_size=1,stride=2, bias=False)
        self.conv14 = nn.Conv2d(256, 512, kernel_size=3,stride=2,padding=1,bias=False)
        self.conv15 = nn.Conv2d(512, 512, kernel_size=3,stride=1,padding=1,bias=False)
        self.conv16 = nn.Conv2d(512, 512, kernel_size=3,stride=1,padding=1,bias=False)
        self.conv17 = nn.Conv2d(512, 512, kernel_size=3,stride=1,padding=1,bias=False)
        
        # define batchnorm layers, one per channel width
        self.bn1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.bn2 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.bn3 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.bn4 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, 3, bias=True)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        ########### layer 1
        # Block 1
        identity = x
        x = self.conv2(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.bn1(x)
        x = x + identity
        x = self.relu(x)
        # Block 2
        identity = x
        x = self.conv4(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv5(x)
        x = self.bn1(x)
        x = x + identity
        x = self.relu(x)

        ############# layer 2
        # Block 1
        identity = x
        x = self.conv6(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv7(x)
        x = self.bn2(x)
        identity = self.conv6_downsample(identity)
        x = x + identity
        x = self.relu(x)
        # Block 2
        identity = x
        x = self.conv8(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv9(x)
        x = self.bn2(x)
        x = x + identity
        x = self.relu(x)
        
        ############# layer 3
        # Block 1
        identity = x
        x = self.conv10(x)
        x = self.bn3(x)
        x = self.relu(x)
        x = self.conv11(x)
        x = self.bn3(x)
        identity = self.conv10_downsample(identity)
        x = x + identity
        x = self.relu(x)
        # Block 2
        identity = x
        x = self.conv12(x)
        x = self.bn3(x)
        x = self.relu(x)
        x = self.conv13(x)
        x = self.bn3(x)
        x = x + identity
        x = self.relu(x)
        
        ############# layer 4
        # Block 1
        identity = x
        x = self.conv14(x)
        x = self.bn4(x)
        x = self.relu(x)
        x = self.conv15(x)
        x = self.bn4(x)
        identity = self.conv14_downsample(identity)
        x = x + identity
        x = self.relu(x)
        # Block 2
        identity = x
        x = self.conv16(x)
        x = self.bn4(x)
        x = self.relu(x)
        x = self.conv17(x)
        x = self.bn4(x)
        x = x + identity
        x = self.relu(x)

        # avg pool before flatten
        x = self.avgpool(x)
        x = x.reshape(x.size(0), -1)
        #x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = self.fc(x)
        # no softmax needed here because CrossEntropyLoss applies log-softmax internally
        return x
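
For completeness, a quick smoke test (assuming 224x224 inputs, which matches the summary below) to confirm the forward pass runs and produces one logit per class:

model = CustomRes3()
dummy = torch.randn(2, 3, 224, 224)  # batch of 2 RGB images
print(model(dummy).shape)            # expect torch.Size([2, 3])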

Summary

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Conv2d: 1-1                            [-1, 64, 112, 112]        9,408
├─BatchNorm2d: 1-2                       [-1, 64, 112, 112]        128
├─ReLU: 1-3                              [-1, 64, 112, 112]        --
├─MaxPool2d: 1-4                         [-1, 64, 56, 56]          --
├─Conv2d: 1-5                            [-1, 64, 56, 56]          36,864
├─BatchNorm2d: 1-6                       [-1, 64, 56, 56]          (recursive)
├─ReLU: 1-7                              [-1, 64, 56, 56]          --
├─Conv2d: 1-8                            [-1, 64, 56, 56]          36,864
├─BatchNorm2d: 1-9                       [-1, 64, 56, 56]          (recursive)
├─ReLU: 1-10                             [-1, 64, 56, 56]          --
├─Conv2d: 1-11                           [-1, 64, 56, 56]          36,864
├─BatchNorm2d: 1-12                      [-1, 64, 56, 56]          (recursive)
├─ReLU: 1-13                             [-1, 64, 56, 56]          --
├─Conv2d: 1-14                           [-1, 64, 56, 56]          36,864
├─BatchNorm2d: 1-15                      [-1, 64, 56, 56]          (recursive)
├─ReLU: 1-16                             [-1, 64, 56, 56]          --
├─Conv2d: 1-17                           [-1, 128, 28, 28]         73,728
├─BatchNorm2d: 1-18                      [-1, 128, 28, 28]         256
├─ReLU: 1-19                             [-1, 128, 28, 28]         --
├─Conv2d: 1-20                           [-1, 128, 28, 28]         147,456
├─BatchNorm2d: 1-21                      [-1, 128, 28, 28]         (recursive)
├─Conv2d: 1-22                           [-1, 128, 28, 28]         8,192
├─ReLU: 1-23                             [-1, 128, 28, 28]         --
├─Conv2d: 1-24                           [-1, 128, 28, 28]         147,456
├─BatchNorm2d: 1-25                      [-1, 128, 28, 28]         (recursive)
├─ReLU: 1-26                             [-1, 128, 28, 28]         --
├─Conv2d: 1-27                           [-1, 128, 28, 28]         147,456
├─BatchNorm2d: 1-28                      [-1, 128, 28, 28]         (recursive)
├─ReLU: 1-29                             [-1, 128, 28, 28]         --
├─Conv2d: 1-30                           [-1, 256, 14, 14]         294,912
├─BatchNorm2d: 1-31                      [-1, 256, 14, 14]         512
├─ReLU: 1-32                             [-1, 256, 14, 14]         --
├─Conv2d: 1-33                           [-1, 256, 14, 14]         589,824
├─BatchNorm2d: 1-34                      [-1, 256, 14, 14]         (recursive)
├─Conv2d: 1-35                           [-1, 256, 14, 14]         32,768
├─ReLU: 1-36                             [-1, 256, 14, 14]         --
├─Conv2d: 1-37                           [-1, 256, 14, 14]         589,824
├─BatchNorm2d: 1-38                      [-1, 256, 14, 14]         (recursive)
├─ReLU: 1-39                             [-1, 256, 14, 14]         --
├─Conv2d: 1-40                           [-1, 256, 14, 14]         589,824
├─BatchNorm2d: 1-41                      [-1, 256, 14, 14]         (recursive)
├─ReLU: 1-42                             [-1, 256, 14, 14]         --
├─Conv2d: 1-43                           [-1, 512, 7, 7]           1,179,648
├─BatchNorm2d: 1-44                      [-1, 512, 7, 7]           1,024
├─ReLU: 1-45                             [-1, 512, 7, 7]           --
├─Conv2d: 1-46                           [-1, 512, 7, 7]           2,359,296
├─BatchNorm2d: 1-47                      [-1, 512, 7, 7]           (recursive)
├─Conv2d: 1-48                           [-1, 512, 7, 7]           131,072
├─ReLU: 1-49                             [-1, 512, 7, 7]           --
├─Conv2d: 1-50                           [-1, 512, 7, 7]           2,359,296
├─BatchNorm2d: 1-51                      [-1, 512, 7, 7]           (recursive)
├─ReLU: 1-52                             [-1, 512, 7, 7]           --
├─Conv2d: 1-53                           [-1, 512, 7, 7]           2,359,296
├─BatchNorm2d: 1-54                      [-1, 512, 7, 7]           (recursive)
├─ReLU: 1-55                             [-1, 512, 7, 7]           --
├─AdaptiveAvgPool2d: 1-56                [-1, 512, 1, 1]           --
├─Linear: 1-57                           [-1, 3]                   1,539
==========================================================================================
Total params: 11,170,371
Trainable params: 11,170,371
Non-trainable params: 0
Total mult-adds (G): 1.81
==========================================================================================
Input size (MB): 0.57
Forward/backward pass size (MB): 26.41
Params size (MB): 42.61
Estimated Total Size (MB): 69.60
==========================================================================================

ptrblck replied:

I would generally recommend making sure the number of parameters and buffers is equal to the reference model's.
Based on your code, the number of batchnorm layers is too low: you are reusing the same layers a few times and only initialize 4 nn.BatchNorm2d layers, while the original model has 20 nn.BatchNorm2d instances.
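
For example, a quick sketch to compare the two (assuming torchvision is installed; the parameter counts will also differ at the fc layer since the reference has a 1000-class head):

from torchvision import models

ref = models.resnet18(weights=None)
custom = CustomRes3()

def count(model):
    n_params = sum(p.numel() for p in model.parameters())
    n_buffers = sum(b.numel() for b in model.buffers())
    return n_params, n_buffers

print(count(ref))     # reference parameter and buffer counts
print(count(custom))  # the buffer count is far lower here due to the BN reuse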

Also note that your downsample layers are plain conv layers:

self.conv6_downsample = nn.Conv2d(64, 128, kernel_size=1,stride=2, bias=False)

while the reference uses:

model.layer2[0].downsample

Sequential(
  (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
  (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
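
So a matching shortcut for layer 2 would be something like this sketch (the attribute name is illustrative):

self.downsample2 = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
    nn.BatchNorm2d(128),  # the BN that the plain conv version is missing
)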

Hi ptrblck!

Thanks, I got it working with the same parameter and buffer counts as the torchvision model!

Learning point: a BatchNorm layer cannot be reused like a stateless function, since each nn.BatchNorm2d instance carries its own learnable weights and running statistics.
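
For anyone hitting the same issue, the fix is one nn.BatchNorm2d instance per conv, e.g. for the first basic block (a sketch; the bn attribute names are just illustrative):

# in __init__: a separate BN for each conv instead of reusing self.bn1
self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2_a = nn.BatchNorm2d(64)
self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2_b = nn.BatchNorm2d(64)

# in forward, for the first basic block:
identity = x
x = self.relu(self.bn2_a(self.conv2(x)))
x = self.bn2_b(self.conv3(x))
x = self.relu(x + identity)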