I am trying to train a ResNet model for CIFAR10 in PyTorch. The model is just not training, i.e., the accuracy is not improving at all. I tried to overfit on a small sample of 10 batches (batch size 10), but it failed miserably.
However, when I change the model from my custom ResNet code to torchvision.models.resnet18, everything works perfectly fine. So I am assuming the issue is with my custom ResNet architecture. I would appreciate it if someone could help me figure out the issue. Here’s my custom ResNet model:
import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.models as models
import torch.optim as optim
import torch.nn as nn
import random
import argparse
class EmptyShortcutLayer(nn.Module):
    """Identity mapping used as the shortcut path when no projection is needed.

    Functionally equivalent to ``nn.Identity``; kept as a named module so the
    shortcut output of every block can be hooked/visualized uniformly,
    whether or not the block needs a projection shortcut.
    """

    def __init__(self):
        super(EmptyShortcutLayer, self).__init__()

    def forward(self, inp):
        # Pass the input through unchanged.
        return inp
class ResNet18(nn.Module):
# do all layer manually
# DO NOT CREATE sequential for blocks
def __init__(self, input_size, block="Basic", num_classes=10):
super(ResNet18, self).__init__()
# blocks 2,2,2,2
self.in_planes = 64
self.input_size = input_size
self.output_size = self.input_size
self.num_classes = num_classes
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
# batch normalization takes C as argument, i.e. the channels
self.bn1 = nn.BatchNorm2d(64)
# segment 1
# block 1
self.block_1_conv1 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.block_1_bn1 = nn.BatchNorm2d(64)
self.block_1_conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.block_1_bn2 = nn.BatchNorm2d(64)
self.block_1_shortcut = EmptyShortcutLayer()
# block 2
self.block_2_conv1 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.block_2_bn1 = nn.BatchNorm2d(64)
self.block_2_conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.block_2_bn2 = nn.BatchNorm2d(64)
self.block_2_shortcut = EmptyShortcutLayer()
# segment 2
# block 3
self.block_3_conv1 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False)
self.block_3_bn1 = nn.BatchNorm2d(128)
# stride 2, so half it
self.output_size = self.output_size/2
self.block_3_conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
self.block_3_bn2 = nn.BatchNorm2d(128)
self.block_3_shortcut = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False), nn.BatchNorm2d(128))
# block 4
self.block_4_conv1 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
self.block_4_bn1 = nn.BatchNorm2d(128)
self.block_4_conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
self.block_4_bn2 = nn.BatchNorm2d(128)
self.block_4_shortcut = EmptyShortcutLayer()
# segment 3
# block 5
self.block_5_conv1 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False)
self.block_5_bn1 = nn.BatchNorm2d(256)
# stride =2, output size is halved
self.output_size = self.output_size/2
self.block_5_conv2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
self.block_5_bn2 = nn.BatchNorm2d(256)
self.block_5_shortcut = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False), nn.BatchNorm2d(256))
# block 6
self.block_6_conv1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
self.block_6_bn1 = nn.BatchNorm2d(256)
self.block_6_conv2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
self.block_6_bn2 = nn.BatchNorm2d(256)
self.block_6_shortcut = EmptyShortcutLayer()
# segment 4
# block 7
self.block_7_conv1 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
self.block_7_bn1 = nn.BatchNorm2d(512)
# stride=2, output size is halved
self.output_size = self.output_size/2
self.block_7_conv2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
self.block_7_bn2 = nn.BatchNorm2d(512)
self.block_7_shortcut = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False), nn.BatchNorm2d(512))
# block 8
self.block_8_conv1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
self.block_8_bn1 = nn.BatchNorm2d(512)
self.block_8_conv2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
self.block_8_bn2 = nn.BatchNorm2d(512)
self.block_8_shortcut = EmptyShortcutLayer()
# avg pooling
print("self output isze",self.output_size)
self.avg_pool = nn.AvgPool2d(self.output_size)
self.output_size = self.output_size/4
# final layer
self.linear = nn.Linear(512, self.num_classes)
def forward(self, x):
y = x
# first conv
y = self.conv1(y)
y = self.bn1(y)
# first block
out = self.block_1_conv1(y)
out = self.block_1_bn1(out)
out = self.block_1_conv2(out)
out = self.block_1_bn2(out)
y = self.block_1_shortcut(y) + out
# second block
out = self.block_2_conv1(y)
out = self.block_2_bn1(out)
out = self.block_2_conv2(out)
out = self.block_2_bn2(out)
y = self.block_2_shortcut(y) + out
# third block
out = self.block_3_conv1(y)
out = self.block_3_bn1(out)
out = self.block_3_conv2(out)
out = self.block_3_bn2(out)
y = self.block_3_shortcut(y) + out
# fourth block
out = self.block_4_conv1(y)
out = self.block_4_bn1(out)
out = self.block_4_conv2(out)
out = self.block_4_bn2(out)
y = self.block_4_shortcut(y) + out
# fifth block
out = self.block_5_conv1(y)
out = self.block_5_bn1(out)
out = self.block_5_conv2(out)
out = self.block_5_bn2(out)
y = self.block_5_shortcut(y) + out
# sixth block
out = self.block_6_conv1(y)
out = self.block_6_bn1(out)
out = self.block_6_conv2(out)
out = self.block_6_bn2(out)
y = self.block_6_shortcut(y) + out
# seveth block
out = self.block_7_conv1(y)
out = self.block_7_bn1(out)
out = self.block_7_conv2(out)
out = self.block_7_bn2(out)
y = self.block_7_shortcut(y) + out
# eigth block
out = self.block_8_conv1(y)
out = self.block_8_bn1(out)
out = self.block_8_conv2(out)
out = self.block_8_bn2(out)
y = self.block_8_shortcut(y) + out
# avg pool
y = self.avg_pool(y)
y = y.view(y.size(0), -1)
y = self.linear(y)
return y
I have explicitly created all layers in ResNet18 in this code, as after training I want to visualize the output of each layer.
I would appreciate it if someone could give me some pointers to debug this.