Transfer Learning: Grayscale, Image Size and Activation Function

Hey guys,

I am trying to do the following, but I am new to PyTorch, and the transfer learning tutorial covers a rather special case, so I cannot find the information I need to apply it to my own problem and setup.

What I want to do is:
I have a large dataset and I want to use most of AlexNet's pretrained weights and finetune.
What I want to do in steps:

  1. Remove the last FC layer from AlexNet and implement my own FC layer with 10 classes, but with a sigmoid activation function.
  2. Freeze all other weights and only train the new FC layer I created above.
  3. My images are grayscale (1 channel) and 256x256 in size. How do I handle this and what do I have to change? I think the models are trained with 3-channel RGB images and usually a different size like 224x224. I cannot change the size of my images because I am supposed to stick to that size.

My ideas:

  1. Load the AlexNet model (with pretrained=True).
  2. Freeze all parameters/weights.
  3. Remove the last FC layer and replace it with my own 10-class FC layer.
  4. Normalise the output of the FC layer and apply sigmoid onto it.
  5. Change the activation function of that FC layer to sigmoid and get rid of the dropout in AlexNet.
  6. Access the forward function of AlexNet and do something like x = torch.cat((x, x, x), 0) in order to create a fake RGB image.

But I do not know how to stick these concepts together.
Any help is highly appreciated, thank you!


Hi Munichma,

Welcome to the PyTorch community. In my opinion, PyTorch is an excellent framework to tackle your problem, so let's start.

The Custom Model
It looks like you want to alter the fully-connected part by removing the Dropout layers, adding a sigmoid activation function and changing the number of output nodes (from 1000 to 10).

To achieve this, I think it is best to create a new model which uses the entire Alexnet feature extractor, and only uses the fully-connected layers that we want to keep:

import torch
import torch.nn as nn
from torchvision import models

# load the original AlexNet model with pre-trained weights
model = models.alexnet(pretrained=True)

# create custom Alexnet
class CustomAlexnet(nn.Module):
    def __init__(self, num_classes):
        super(CustomAlexnet, self).__init__()
        # reuse the entire AlexNet feature extractor (the pre-trained modules are shared)
        self.features = nn.Sequential(*list(model.features.children()))
        # keep the two hidden FC layers plus their ReLUs (indices 1, 2, 4, 5),
        # then append our own num_classes output layer with a sigmoid activation
        self.classifier = nn.Sequential(
            *[list(model.classifier.children())[i] for i in [1, 2, 4, 5]],
            nn.Linear(4096, num_classes),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        x = self.features(x)
        # flatten the 256x6x6 feature maps before the classifier
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x

# load custom model
model2 = CustomAlexnet(num_classes=10)

Alright, in the code above we use the entire feature extractor part of the original AlexNet model by saving it as:

self.features = nn.Sequential(*list(model.features.children()))

Next, we keep the first and second fully-connected layers, along with their ReLU activations, from the original model. We need the layer indices to select them. If we call model.classifier we get the original fully-connected stack:

Out[103]:
Sequential (
  (0): Dropout (p = 0.5)
  (1): Linear (9216 -> 4096)
  (2): ReLU (inplace)
  (3): Dropout (p = 0.5)
  (4): Linear (4096 -> 4096)
  (5): ReLU (inplace)
  (6): Linear (4096 -> 1000)
)

Since we do not want the last layer or the Dropout layers, we only need layers 1, 2, 4 and 5, along with our custom output layer:

self.classifier = nn.Sequential(
    *[list(model.classifier.children())[i] for i in [1, 2, 4, 5]],
    nn.Linear(4096, num_classes),
    nn.Sigmoid()
)

Because the child modules are reused rather than copied, the pre-trained weights are transferred to our custom model as well. However, it is best to verify it:

# feature layers parameters
print(list(model2.features.parameters()) == list(model.features.parameters())) # True

# first FC layer parameters
print(list(list(model2.classifier.children())[0].parameters()) \
      == list(list(model.classifier.children())[1].parameters())) # True

# second FC layer parameters
print(list(list(model2.classifier.children())[2].parameters()) \
      == list(list(model.classifier.children())[4].parameters())) # True

Freezing weights
Good, we now have our custom model all set with the correct pre-trained weights. Next we want to freeze all weights except the weights and biases of the last fully-connected layer. Every parameter has a .requires_grad attribute, which, if set to False, means the parameter will not be updated during training. Thus:

# freeze all layers
for param in model2.parameters():
    param.requires_grad = False

This freezes our last layer as well, so let's unfreeze it:

# unfreeze last fc layer: the classifier holds three Linear layers, i.e. six
# parameter tensors; indices 4 and 5 are the new output layer's weight and bias
for layer_idx, param in enumerate(model2.classifier.parameters()):
    if layer_idx > 3:
        param.requires_grad = True

Let's verify that the unfrozen parameters are in fact from the correct layer:

# assert that the unfrozen weights are indeed last fc layer
unfrozen_weights = filter(lambda x: x.requires_grad, model2.parameters())
print(list(map(lambda x: x.size(), unfrozen_weights))) # [torch.Size([10, 4096]), torch.Size([10])]

Good, the shapes suggest that the weights and biases of the last fully-connected layer are unfrozen.
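One related note: some PyTorch versions complain if you hand frozen parameters to an optimizer, so it is safest to construct the optimizer from the trainable parameters only. A minimal sketch (the learning rate and momentum are just placeholder values):

import torch.optim as optim

# pass only the parameters that still require gradients to the optimizer
trainable_params = filter(lambda p: p.requires_grad, model2.parameters())
optimizer = optim.SGD(trainable_params, lr=0.001, momentum=0.9)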

Image Size
We only need to address the problem of the input size and we’re good to go! PyTorch has very good support for data loading, image processing and creating batch iterators. I highly suggest checking out the torch.utils.data.DataLoader (for loading batches) and torchvision.datasets.ImageFolder (for loading and processing custom datasets) functionalities.

For the sake of this example, I will use the MNIST dataset because it also has only 1 channel. We can define the preprocessing steps with torchvision.transforms, which will include concatenating the channels to create the 'fake' RGB images:

import torch
from torchvision import transforms

data_transform = transforms.Compose([
    transforms.ToTensor(),
    # repeat the single channel three times to get a 'fake' RGB image
    transforms.Lambda(lambda x: torch.cat([x, x, x], 0))
])

Next, we load the data with these transformations. Note that ImageFolder also has the transform argument, so this will also work for custom datasets.
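For your own grayscale images, that could look roughly like this (the folder path is just an assumption; ImageFolder expects one subfolder per class). Note that ImageFolder's default loader converts images to RGB on load, so the channel concatenation may already be handled for you there:

from torchvision.datasets import ImageFolder

# hypothetical layout: ./my_data/<class_name>/xxx.png
# the default loader yields RGB PIL images, so ToTensor alone may suffice here
custom_dataset = ImageFolder(root='./my_data/', transform=transforms.ToTensor())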

from torchvision.datasets import MNIST

dataset = MNIST(root='./Desktop/', transform=data_transform)

Finally, create a data loader from the dataset and we're good!

from torch.utils.data import DataLoader

data_loader = DataLoader(dataset, batch_size=1)

Let's see the result:

for x, y in data_loader:
    print(x.size()) # torch.Size([1, 3, 28, 28])
    break
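One caveat: these MNIST tensors are 28x28, so to actually push them through AlexNet you would still need to resize them to 224x224 first, e.g. by adding a resize step to the transform (transforms.Scale here; newer torchvision versions call it transforms.Resize):

data_transform = transforms.Compose([
    transforms.Scale(224),  # transforms.Resize(224) in newer torchvision
    transforms.ToTensor(),
    transforms.Lambda(lambda x: torch.cat([x, x, x], 0))
])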

Super, we have a 'fake' RGB image. I think I have covered everything from your post. I hope this will help you with the problem, and also teach you some PyTorch fundamentals. If you have any questions, please let me know!

Daniel


Hey Daniel!

Thank you so much for this extremely detailed answer. I just tried it out, but unfortunately I am getting this strange message:

File "<ipython-input-3-4a6dcde8a731>", line 11
  nn.Linear(4096, num_classes),
SyntaxError: only named arguments may follow *expression

I am not able to modify self.classifier by adding these lines:

nn.Linear(4096, num_classes),
nn.Sigmoid()

These two lines seem to cause the trouble, because without them I am at least able to pick out the layers 1, 2, 4, 5 that I want from the AlexNet classifier.

Do you know why it is not working? I basically copied the code above one-to-one.
Moreover, I wanted to ask you where I am supposed to put these weight freezing and unfreezing statements:

# freeze all layers
for param in model2.parameters():
    param.requires_grad = False

# unfreeze last fc layer
for layer_idx, param in enumerate(model2.classifier.parameters()):
    if layer_idx > 3:
        param.requires_grad = True

Should they go within the definition of my CustomAlexnet class, or can this be done after I have initialized my own net with this statement:

# load custom model
model2 = CustomAlexnet(num_classes=10)

Thanks a lot! I was about to give up diving into PyTorch and go over to Keras, because there are barely any examples and tutorials doing similar things with PyTorch.

Amine

Hi Amine,

Are you using Python 2? Since Python 3.5, starred expressions in function calls have been generalized so that further positional arguments may follow them (PEP 448); in Python 2, only keyword arguments may follow a *expression.

A Python 2 compatible workaround could be:

self.classifier = nn.Sequential(
    *([list(model.classifier.children())[i] for i in [1, 2, 4, 5]] +
      [nn.Linear(4096, num_classes),
       nn.Sigmoid()])
)

so that everything is passed as one big starred expression, with no positional arguments following it.

The layer freezing is done outside the CustomAlexnet definition, after you have initialized the network.
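In other words, something like this:

# initialize the custom network first ...
model2 = CustomAlexnet(num_classes=10)

# ... then freeze everything and unfreeze the new output layer
for param in model2.parameters():
    param.requires_grad = False

for layer_idx, param in enumerate(model2.classifier.parameters()):
    if layer_idx > 3:
        param.requires_grad = True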

Hope it works.

Daniel

Thank you a lot! Yeah, I am using Python 2.

I solved it by taking the classifier as a list, appending nn.Linear(4096, num_classes), and then appending nn.Sigmoid().
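In case anyone else hits this, my fix looked roughly like this:

# build the layer list first, then unpack it in a single starred expression
layers = [list(model.classifier.children())[i] for i in [1, 2, 4, 5]]
layers.append(nn.Linear(4096, num_classes))
layers.append(nn.Sigmoid())
self.classifier = nn.Sequential(*layers)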

Another last question:

Does AlexNet need images of exactly 224x224? When I tried to run it with 256x256 it did not work and I got this strange error message:

pytorch RuntimeError: size mismatch at /opt/conda/conda-bld/pytorch_1503966894950/work/torch/lib/THC/generic/THCTensorMathBlas.cu:243

On the other hand, I got ResNet18 working with images of size 256x256. Could you explain why this is the case?

Thanks!

Image input size depends on the network's architectural decisions.

Parameters such as the stride and kernel size dictate each layer's output size, which subsequent layers are then counting on. So if you use a different input size, it might end up conflicting with the pre-defined set of tensor operations.

Sometimes you can get away with using a different input size (if there is no conflict in tensor dimensions), but it is recommended to use the original input size the network was designed for (I think ResNet and AlexNet use 224x224 while Inception uses 299x299).
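You can see the conflict directly for AlexNet. Each conv/pool layer computes out = floor((in + 2*padding - kernel_size) / stride) + 1, and with a 224x224 input the feature extractor ends at 256x6x6 = 9216 values, exactly what the first Linear layer (9216 -> 4096) expects. With a 256x256 input you get 256x7x7 = 12544 values instead, hence the size mismatch. A quick sketch to check this (written for a recent PyTorch, where Variable wrapping is no longer needed):

import torch
from torchvision import models

model = models.alexnet(pretrained=True)

# 224x224 input: 256x6x6 = 9216 values, matching the first Linear layer
print(model.features(torch.randn(1, 3, 224, 224)).size())
# torch.Size([1, 256, 6, 6])

# 256x256 input: 256x7x7 = 12544 values, which no longer fits the classifier
print(model.features(torch.randn(1, 3, 256, 256)).size())
# torch.Size([1, 256, 7, 7])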

Hey thanks!

Then I think it's possible to get away with a different image size for the ResNet18 implementation in PyTorch's model zoo, but it does not work for AlexNet, since I got the error message above (I just updated my post with it after searching for it), which can be fixed by changing the input images to 224x224.