Transfer Learning Grayscale, Image Size and Activation Function

Hi Munichma,

Welcome to the PyTorch community. In my opinion, PyTorch is an excellent framework to tackle your problem, so let's start.

The Custom Model
It looks like you want to alter the fully-connected part of the network by removing the Dropout layers, adding a sigmoid activation function, and changing the number of output nodes (from 1000 to 10).

To achieve this, I think it is best to create a new model that reuses the entire AlexNet feature extractor and only the fully-connected layers that we want to keep:

import torch
import torch.nn as nn
from torchvision import models

# load the original AlexNet model
model = models.alexnet(pretrained=True)

# create custom AlexNet
class CustomAlexnet(nn.Module):
    def __init__(self, num_classes):
        super(CustomAlexnet, self).__init__()
        # reuse the entire pre-trained feature extractor
        self.features = nn.Sequential(*list(model.features.children()))
        # keep the two hidden fc layers (+ ReLUs), drop the Dropout layers,
        # and append a new num_classes output layer followed by a sigmoid
        self.classifier = nn.Sequential(
            *[list(model.classifier.children())[i] for i in [1, 2, 4, 5]],
            nn.Linear(4096, num_classes),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)  # flatten the 256 x 6 x 6 feature maps
        x = self.classifier(x)
        return x

# load custom model
model2 = CustomAlexnet(num_classes=10)

Alright, in the code above we reuse the entire feature extractor part of the original AlexNet model by saving it as:

self.features = nn.Sequential(*list(model.features.children()))

Next, we keep the first and second fully-connected layers of the original model, along with their ReLU activations. We need the layer indices to select them. If we call model.classifier, we get the original fully-connected stack:

Out[103]:
Sequential (
  (0): Dropout (p = 0.5)
  (1): Linear (9216 -> 4096)
  (2): ReLU (inplace)
  (3): Dropout (p = 0.5)
  (4): Linear (4096 -> 4096)
  (5): ReLU (inplace)
  (6): Linear (4096 -> 1000)
)

Since we do not want the last layer or the Dropout layers, we only need layers 1, 2, 4, and 5, along with our custom output layer:

self.classifier = nn.Sequential(
    *[list(model.classifier.children())[i] for i in [1, 2, 4, 5]],
    nn.Linear(4096, num_classes),
    nn.Sigmoid()
)

This method also transfers the pre-trained weights to our custom model: the new Sequential containers wrap the very same module objects, so the parameters are shared with the original model. Still, it is best to verify this:

# feature layers parameters
print(list(model2.features.parameters()) == list(model.features.parameters())) # True

# first FC layer parameters
print(list(list(model2.classifier.children())[0].parameters()) \
      == list(list(model.classifier.children())[1].parameters())) # True

# second FC layer parameters
print(list(list(model2.classifier.children())[2].parameters()) \
      == list(list(model.classifier.children())[4].parameters())) # True
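
As an extra sanity check, a forward pass with a random batch should give one sigmoid output per class. A small sketch (224x224 is the input size the original AlexNet was designed for):

# hypothetical random batch: 2 RGB images of size 224x224
dummy = torch.randn(2, 3, 224, 224)
print(model2(dummy).size())  # torch.Size([2, 10])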

Freezing weights
Good, we now have our custom model set up with the correct pre-trained weights. Next, we want to freeze all weights except the weights and biases of the last fully-connected layer. Every parameter has a .requires_grad attribute; if it is set to False, the parameter will not be updated during training. Thus:

# freeze all layers
for param in model2.parameters():
    param.requires_grad = False
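
If you want, you can quickly confirm that nothing is trainable at this point:

# no parameter should require gradients after freezing everything
print(any(p.requires_grad for p in model2.parameters()))  # False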

Of course, our last layer is now frozen as well, so let's unfreeze it:

# unfreeze last fc layer
# the classifier parameters come in the order: fc1.weight, fc1.bias,
# fc2.weight, fc2.bias, fc3.weight, fc3.bias, so indices 4 and 5
# belong to our new output layer
for param_idx, param in enumerate(model2.classifier.parameters()):
    if param_idx > 3:
        param.requires_grad = True

Let's verify that the unfrozen parameters are in fact from the correct layer:

# assert that the unfrozen weights are indeed last fc layer
unfrozen_weights = filter(lambda x: x.requires_grad, model2.parameters())
print(list(map(lambda x: x.size(), unfrozen_weights))) # [torch.Size([10, 4096]), torch.Size([10])]

Good, the shapes suggest that the weights and biases of the last fully-connected layer are unfrozen.
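
One more practical note: when you set up training, you can pass only the still-trainable parameters to the optimizer. A minimal sketch (the learning rate and momentum are just placeholder values):

import torch.optim as optim

# only hand the trainable parameters (the new fc layer) to the optimizer
trainable_params = [p for p in model2.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=0.01, momentum=0.9)  # example hyperparameters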

Image Size
We only need to address the problem of the input size and we're good to go! PyTorch has very good support for data loading, image processing, and creating batch iterators. I highly suggest checking out torch.utils.data.DataLoader (for loading batches) and torchvision.datasets.ImageFolder (for loading and processing custom datasets).

For the sake of this example, I will use the MNIST dataset because it also has only 1 channel. We can define the preprocessing steps with torchvision.transforms, including the channel concatenation that creates the ‘fake’ RGB images:

import torch
from torchvision import transforms

data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: torch.cat([x, x, x], 0))  # 1 channel -> 3 identical channels
])
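
One thing to keep in mind: the AlexNet feature extractor was designed for 224x224 inputs, so for actual training on your own images you would probably also add a resize step to the same pipeline. A sketch of what that could look like (transforms.Resize assumes a reasonably recent torchvision; older versions call it transforms.Scale). For the demonstration below I will stick with the plain data_transform so the output shapes stay easy to read:

resize_transform = transforms.Compose([
    transforms.Resize((224, 224)),                         # AlexNet expects 224x224 inputs
    transforms.ToTensor(),
    transforms.Lambda(lambda x: torch.cat([x, x, x], 0))   # grayscale -> 3 identical channels
])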

Next, we load the data with data_transform. Note that ImageFolder also has a transform argument, so this will also work for custom datasets.

from torchvision.datasets import MNIST

dataset = MNIST(root='./Desktop/', transform=data_transform, download=True)  # download if not already present

Finally, create a data loader from the dataset and we're good!

from torch.utils.data import DataLoader

data_loader = DataLoader(dataset, batch_size=1)

Let's see the result:

for x, y in data_loader:
    print(x.size()) # torch.Size([1, 3, 28, 28])
    break

Super, we have a ‘fake’ RGB image. I think I have covered everything from your post. I hope this helps you with the problem and also teaches you some PyTorch fundamentals. If you have any questions, please let me know!

Daniel
