Hi All,
I’m trying to turn AlexNet into a binary classifier. I want to add a Softmax layer after the classifier of the pretrained AlexNet so that the output of the last layer can be interpreted as probabilities. The code I have so far is:
import torch.nn as nn
from torchvision import models

model_ft = models.alexnet(pretrained=True)
# Freeze the weights of the first four conv layers
layers_to_freeze = [model_ft.features[0], model_ft.features[3], model_ft.features[6], model_ft.features[8]]
for layer in layers_to_freeze:
    for params in layer.parameters():
        params.requires_grad = False
# Replace the final 1000-class layer with a 2-class output layer
model_ft.classifier[6] = nn.Linear(in_features=4096, out_features=2, bias=True)
# Append a softmax over the class dimension (dim=1)
model_ft = nn.Sequential(model_ft, nn.Softmax(dim=1))
After this, printing the model shows:
Sequential(
  (0): AlexNet(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
      (1): ReLU(inplace)
      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): ReLU(inplace)
      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU(inplace)
      (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): ReLU(inplace)
      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace)
      (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
    (classifier): Sequential(
      (0): Dropout(p=0.5)
      (1): Linear(in_features=9216, out_features=4096, bias=True)
      (2): ReLU(inplace)
      (3): Dropout(p=0.5)
      (4): Linear(in_features=4096, out_features=4096, bias=True)
      (5): ReLU(inplace)
      (6): Linear(in_features=4096, out_features=2, bias=True)
    )
  )
  (1): Softmax()
)
Even after adding this layer, the output tensor has shape (4, 2):
output = model_ft(inputs)
print(output.shape)
This prints torch.Size([4, 2]).
Why is the output of shape (4, 2)? Shouldn’t it be of shape (1, 2), as shown in this answer?
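For what it's worth, here is a minimal plain-Python sketch (no PyTorch) of what I understand a softmax over dim 1 to do on a 2-D input: it normalizes each row separately, so the number of rows should not change. The logits are made-up values and `softmax_rows` is a hypothetical helper, not a library function. If this is right, the 4 would simply be the number of samples in `inputs`, which is an assumption on my part since the data loading code isn't shown here.

```python
import math


def softmax_rows(batch):
    """Softmax applied to each row independently (hypothetical helper).

    Intended to mirror what a softmax over dim=1 does for a 2-D input:
    one probability vector per row, same number of rows in and out.
    """
    out = []
    for row in batch:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


# Made-up logits: 4 rows (one per sample in the batch) x 2 classes.
logits = [[2.0, 0.5], [-1.0, 1.0], [0.0, 0.0], [3.0, -3.0]]
probs = softmax_rows(logits)

# Rows x columns of the result: still one row per input sample.
print(len(probs), len(probs[0]))

# Each row is a separate probability distribution over the 2 classes.
for row in probs:
    print(row, sum(row))

# A batch containing a single sample gives a single row.
single = softmax_rows([[2.0, 0.5]])
print(len(single), len(single[0]))
```

So a (1, 2) result would only be expected when a single sample is passed in.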