Transfer learning: how to modify the first Conv2d layer of AlexNet to accommodate a 9-channel input?

I am using AlexNet as a feature extractor on a 9-channel image (3 images concatenated), so I need to modify the first conv2d layer to accept 9-channel input. What I tried:
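For context, a 9-channel input of this kind can be built by concatenating three RGB images along the channel dimension. This is a minimal sketch, assuming NCHW tensors of equal size; the random tensors stand in for real images.

```python
import torch

# Three RGB images of the same size, batch of 4 (random stand-ins for real data).
img_a, img_b, img_c = (torch.randn(4, 3, 224, 224) for _ in range(3))

# Concatenate along the channel dimension (dim=1 in NCHW layout):
# channels 0-2 come from img_a, 3-5 from img_b, 6-8 from img_c.
x = torch.cat([img_a, img_b, img_c], dim=1)
print(x.shape)  # torch.Size([4, 9, 224, 224])
```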

```python
import torch
import torch.nn as nn
import torchvision

model_conv = torchvision.models.alexnet(pretrained=True)
new_features = nn.Sequential(*list(model_conv.features.children()))
model_conv.features = new_features
# Stack the pretrained (64, 3, 11, 11) weights three times along the input-channel dim
new_9_tensor = torch.cat((new_features.state_dict().get('0.weight'),
                          new_features.state_dict().get('0.weight'),
                          new_features.state_dict().get('0.weight')), 1)
new_features.state_dict().__setitem__('0.weight', new_9_tensor)
print("named modules", list(new_features.named_modules()))
```

Although the state_dict is modified to show the required 64x9x11x11 tensor, the output of the last line (the list of named modules) shows no change in the conv2d layer. Can someone please help me with the changes I need to make for AlexNet to work with a 9-channel input? Is there an easier way to do this without accessing private attributes?
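The symptom above can be reproduced on a single layer. A likely explanation is that `state_dict()` builds and returns a new dictionary on every call, so writing a key into it never reaches the module, and passing that dictionary back through `load_state_dict` fails with a size mismatch because the layer was still constructed for 3 input channels. A minimal sketch using a standalone `nn.Conv2d` with AlexNet's first-layer settings:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2)

# state_dict() returns a fresh dict; replacing an entry only changes that dict.
sd = conv.state_dict()
sd["weight"] = torch.cat([sd["weight"]] * 3, dim=1)  # shape (64, 9, 11, 11)

print(sd["weight"].shape)                 # torch.Size([64, 9, 11, 11])
print(conv.weight.shape)                  # torch.Size([64, 3, 11, 11]), unchanged
print(conv.state_dict()["weight"].shape)  # new dict, still 3 input channels

# Loading it back does not help: the layer itself still expects 3 channels.
try:
    conv.load_state_dict(sd)
except RuntimeError as err:
    print("size mismatch" in str(err))    # True
```

In other words, the shape of the layer has to change (a new `Conv2d` with `in_channels=9`) before 9-channel weights can be loaded into it.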

Is there a way to remove the first layer and replace it with Conv2d(9, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))?

Any sort of help is really appreciated. Thanks!

I believe you want Conv2d(9, 64, kernel_size=11, stride=4, padding=2)
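For reference, the integer and tuple spellings configure the same layer: `nn.Conv2d` expands integer `kernel_size`, `stride` and `padding` into pairs. A quick check:

```python
import torch.nn as nn

a = nn.Conv2d(9, 64, kernel_size=11, stride=4, padding=2)
b = nn.Conv2d(9, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))

# Integer arguments are stored as pairs, so both layers match.
print(a.kernel_size, a.stride, a.padding)  # (11, 11) (4, 4) (2, 2)
print(a.weight.shape == b.weight.shape)    # True
```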


My bad! That was a typo in the post, not in the original code. Re-initializing the conv2d layer would randomly initialize its weights. Besides, the self.features = nn.Sequential( ... container doesn't seem to allow changing only the first conv2d layer while keeping all the other weights as they were loaded by the first statement above (please correct me if I'm wrong). Still looking for a clue on how to modify the first layer 😐

I think the easiest way to do this is to create your model by modifying the AlexNet model from torchvision so that the first conv layer takes 9 input channels.

Then, to copy the weights over, copy them from the original AlexNet state_dict into the state_dict of your modified model. It should be a direct assignment for every layer except the first conv layer.


Does this work well? Is there a better way?