How to change the number of input channels of a pretrained model?

Hi,

I have loaded the pre-trained AlexNet model in PyTorch. My images have 6 channels, so I want to change the number of input channels of the first convolutional layer from 3 to 6. How can I do that? I also want to keep the pre-trained weights for the RGB channels and initialize the weights for the other 3 channels with random numbers.

Any help will be appreciated. Thanks.

Hi,

If you want to pass 6-channel images to AlexNet directly, you will need to replace its first convolution layer.
Because the model is defined with nn.Sequential, replacing a specific layer is a bit confusing.

Alternatively, you can create a pretrained and an untrained AlexNet with torchvision.models.alexnet(pretrained=True) and torchvision.models.alexnet(pretrained=False) respectively. Then split each 6-channel image into two 3-channel images before passing them to the two AlexNets.
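A minimal sketch of the two-network idea, in case it helps. The wrapper class name `TwoStreamEncoder` is made up for illustration, and concatenating the two outputs is my own assumption, since the suggestion above does not say how to combine them.

```python
import torch
import torch.nn as nn


class TwoStreamEncoder(nn.Module):
    """Hypothetical wrapper: runs channels 0-2 and channels 3+ through separate networks."""

    def __init__(self, rgb_net: nn.Module, extra_net: nn.Module):
        super().__init__()
        self.rgb_net = rgb_net      # e.g. the pretrained AlexNet
        self.extra_net = extra_net  # e.g. the untrained AlexNet

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # images has shape (N, 6, H, W); slice along the channel dimension
        rgb, extra = images[:, :3], images[:, 3:]
        # Assumption: combine the two outputs by concatenating along dim 1
        return torch.cat([self.rgb_net(rgb), self.extra_net(extra)], dim=1)
```

With the two models from the post above this would be used as TwoStreamEncoder(torchvision.models.alexnet(pretrained=True), torchvision.models.alexnet(pretrained=False)), which gives an output of shape (N, 2000) because each AlexNet returns 1000 logits.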

I am using the following code for a 4-channel input:


import torch.nn as nn
from torchvision import models


class RGBMaskEncoderCNN(nn.Module):
	def __init__(self):
		super(RGBMaskEncoderCNN, self).__init__()

		self.alexnet = models.alexnet(pretrained=True)
		# get the pre-trained weights of the first layer (detached from autograd)
		pretrained_weights = self.alexnet.features[0].weight.detach()
		# rebuild the feature extractor with a 4-channel first convolution
		new_features = nn.Sequential(*list(self.alexnet.features.children()))
		new_features[0] = nn.Conv2d(4, 64, kernel_size=11, stride=4, padding=2)

		# The weights for the M (mask) channel are randomly initialized with a Gaussian
		new_features[0].weight.data.normal_(0, 0.001)
		# The weights for the RGB channels are copied from the pretrained weights
		new_features[0].weight.data[:, :3, :, :] = pretrained_weights

		self.alexnet.features = new_features

	def forward(self, images):
		"""Extract Feature Vector from Input Images"""
		features = self.alexnet(images)
		return features

I don’t know whether this is the correct way to do it.


Hi,

Can you please confirm whether this solution worked correctly for you?

For me it leads to the error ‘Sequential’ object does not support item assignment.
Is there a solution to this?

I was thinking of using a for loop to rebuild the whole object, starting with the new first layer and followed by the rest of the AlexNet layers.
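For what it's worth, the rebuild idea above can be sketched without any item assignment on the Sequential: the layers are edited in a plain Python list and then wrapped in a new nn.Sequential. The helper name `widen_first_conv` and the `std` default are made up for illustration, and the sketch assumes the first layer is a plain nn.Conv2d with default groups, dilation and padding mode.

```python
import torch
import torch.nn as nn


def widen_first_conv(features: nn.Sequential, in_channels: int, std: float = 0.001) -> nn.Sequential:
    """Hypothetical helper: return a new nn.Sequential whose first conv takes `in_channels` inputs."""
    layers = list(features.children())
    old_conv = layers[0]
    if not isinstance(old_conv, nn.Conv2d):
        raise TypeError("expected the first layer to be nn.Conv2d")
    if in_channels < old_conv.in_channels:
        raise ValueError("in_channels must not be smaller than the original")

    new_conv = nn.Conv2d(
        in_channels,
        old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        bias=old_conv.bias is not None,
    )
    with torch.no_grad():
        # Random Gaussian init for every input channel first ...
        new_conv.weight.normal_(0.0, std)
        # ... then overwrite the original channels with the existing weights
        new_conv.weight[:, :old_conv.in_channels] = old_conv.weight
        if old_conv.bias is not None:
            new_conv.bias.copy_(old_conv.bias)

    # Swap the layer in the list, then build a fresh Sequential from it
    layers[0] = new_conv
    return nn.Sequential(*layers)
```

For the 6-channel case from the original question this would be applied as model.features = widen_first_conv(model.features, 6) on a loaded AlexNet.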


This code works for me.