class AlexNet(BaseModel):
    """AlexNet-style convolutional classifier.

    The network is a stack of five convolution/ReLU stages (three of them
    followed by max pooling), an adaptive average pool to a 6x6 map, and a
    three-layer fully connected classifier with dropout in front of the
    first two linear layers.

    No normalization layers are used.

    Args:
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, num_classes: int = 1000) -> None:
        super().__init__()

        # Layers are registered in execution order; ``forward`` relies on
        # the insertion order of this ModuleDict.
        self.features = nn.ModuleDict({
            "Conv2d_1": nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            "Act_1": nn.ReLU(inplace=True),
            "Max_1": nn.MaxPool2d(kernel_size=3, stride=2),
            "Conv2d_2": nn.Conv2d(64, 192, kernel_size=5, padding=2),
            "Act_2": nn.ReLU(inplace=True),
            "Max_2": nn.MaxPool2d(kernel_size=3, stride=2),
            "Conv2d_3": nn.Conv2d(192, 384, kernel_size=3, padding=1),
            "Act_3": nn.ReLU(inplace=True),
            "Conv2d_4": nn.Conv2d(384, 256, kernel_size=3, padding=1),
            "Act_4": nn.ReLU(inplace=True),
            # NOTE(review): with kernel_size=3, padding=2 enlarges the
            # feature map by 2 in each spatial dimension, unlike the
            # padding=1 used by Conv2d_3 / Conv2d_4 -- confirm intended.
            "Conv2d_5": nn.Conv2d(256, 256, kernel_size=3, padding=2),
            "Act_5": nn.ReLU(inplace=True),
            "Max_5": nn.MaxPool2d(kernel_size=3, stride=2),
        })

        # Forces a fixed 6x6 spatial size so the first linear layer always
        # receives 256 * 6 * 6 features regardless of input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        self.fullyconnected = nn.ModuleDict({
            # "Pool" duplicates ``self.avgpool`` and is not used by
            # ``forward``; it is kept so existing code that looks it up
            # by key keeps working.
            "Pool": nn.AdaptiveAvgPool2d((6, 6)),
            "drop_6": nn.Dropout(),
            "Linear_6": nn.Linear(256 * 6 * 6, 4096),
            "Act_6": nn.ReLU(inplace=True),
            "drop_7": nn.Dropout(),
            "Linear_7": nn.Linear(4096, 4096),
            "Act_7": nn.ReLU(inplace=True),
            "Linear_8": nn.Linear(4096, num_classes),
        })

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: Image tensor with 3 channels (as required by ``Conv2d_1``).

        Returns:
            Tensor of shape ``(N, num_classes)`` holding the raw output of
            the last linear layer; no softmax is applied.
        """
        # ModuleDict preserves insertion order, so this applies the
        # feature layers in exactly the order they were registered.
        for layer in self.features.values():
            x = layer(x)

        x = self.avgpool(x)
        # ``reshape`` yields the same (N, 9216) result as ``view`` but also
        # accepts non-contiguous inputs.
        x = x.reshape(-1, 256 * 6 * 6)

        classifier = self.fullyconnected
        # Dropout was registered but previously never applied; it is
        # active only in training mode and is the identity in eval mode.
        x = classifier["drop_6"](x)
        x = classifier["Linear_6"](x)
        x = classifier["Act_6"](x)
        x = classifier["drop_7"](x)
        x = classifier["Linear_7"](x)
        x = classifier["Act_7"](x)
        x = classifier["Linear_8"](x)
        return x
The validation accuracy and loss stay the same and the network does not improve when using the CIFAR10 dataset.