Inquiry about the concept of transfer learning

Hello,
I’ve encountered a conceptual question about transfer learning.
Suppose that I load the ResNet50 model (with pretrained=True) and then, for my own purposes, change the first layer (so that it now accepts a one-channel image) and the last layer (in this case, I add an additional fc layer to the end of the network). My question is: what happens to the weights of the first layer, which I changed? In addition, what are the initial weights of my last layer, which wasn't trained in the original model? Lastly, do the weights of the 3-channel model fit properly to a one-channel model?

thanks

You could try adding up the first layer's weights over the three input channels. This gives the same result as copying your single-channel input to all three channels. Ideally, your inputs should be globally normalized to zero mean and unit variance for that.


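If you replace the first layer with a new nn.Conv2d that accepts one channel, that layer gets new, default-initialized weights with a different shape, so its pretrained weights are not used unless you copy or adapt them yourself. All other layers keep the weights of the pretrained ResNet50 you are using, assuming that you have not done any fine-tuning or training with target data.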

Since the last layer is only declared and not trained yet, it is initialized with PyTorch's default initializer for that layer type.
https://discuss.pytorch.org/t/how-are-layer-weights-and-biases-initialized-by-default/1307
If it's a linear layer, check lines 57-64 of pytorch/linear.py (pytorch/pytorch on GitHub).

You would benefit from the pretrained model if the learned weights of all three channels have similar distributions. But you would most likely still need to train all layers with your target data.


Thank you for your reply. I have 2 more questions:

  1. Is the number of weights for a 1-channel input the same as for a 3-channel input? (In my opinion it shouldn't be, so how does the network pick its weights for my 1-channel input?)
  2. For the last part, what did you mean by ‘if learned weights of all three channels have similar distribution’?

The number of weights for a 1-channel input and a 3-channel input won’t be the same. In fact, the number of channels of the input always needs to match the number of input channels of the first layer.
You may try changing the number of channels from 3 to 1 in the PyTorch CIFAR-10 tutorial. You can see the number of parameters with the torch-summary package:
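The difference comes from the weight shape: a Conv2d layer stores one k x k kernel per (output channel, input channel) pair, plus one bias per output channel. A minimal sketch of that arithmetic (the helper name `conv2d_param_count` is hypothetical):

```python
def conv2d_param_count(in_channels, out_channels, kernel_size, bias=True):
    """Parameters of a Conv2d layer with a square kernel."""
    weights = out_channels * in_channels * kernel_size * kernel_size
    return weights + (out_channels if bias else 0)


print(conv2d_param_count(3, 6, 5))               # 456  (CIFAR-10 tutorial conv1)
print(conv2d_param_count(1, 6, 5))               # 156
print(conv2d_param_count(3, 64, 7, bias=False))  # 9408 (ResNet50 conv1)
print(conv2d_param_count(1, 64, 7, bias=False))  # 3136
```

So a 1-channel first layer simply has fewer weights; the network does not "pick" a subset of the 3-channel weights on its own.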

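from torchsummary import summary  # provided by the torch-summary package
summary(net, (3, 32, 32))  # net: the CIFAR-10 tutorial model; input shape is (C, H, W)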
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Conv2d: 1-1                            [-1, 6, 28, 28]           456
├─MaxPool2d: 1-2                         [-1, 6, 14, 14]           --
├─Conv2d: 1-3                            [-1, 16, 10, 10]          2,416
├─MaxPool2d: 1-4                         [-1, 16, 5, 5]            --
├─Linear: 1-5                            [-1, 120]                 48,120
├─Linear: 1-6                            [-1, 84]                  10,164
├─Linear: 1-7                            [-1, 10]                  850
==========================================================================================

If you now change the input to 1 channel, you need to change the definition of the network’s first layer to accept 1 channel. You can then pass a single-channel input shape and view the model parameters with:

summary(net, (1, 32, 32))

You would see fewer parameters in the first layer: 156, compared to 456 with 3 channels.


==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Conv2d: 1-1                            [-1, 6, 28, 28]           156
├─MaxPool2d: 1-2                         [-1, 6, 14, 14]           --
├─Conv2d: 1-3                            [-1, 16, 10, 10]          2,416
├─MaxPool2d: 1-4                         [-1, 16, 5, 5]            --
├─Linear: 1-5                            [-1, 120]                 48,120
├─Linear: 1-6                            [-1, 84]                  10,164
├─Linear: 1-7                            [-1, 10]                  850
==========================================================================================
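For reference, a sketch of the CIFAR-10 tutorial network with the number of input channels made configurable. The `in_channels` argument and the `count_params` helper are hypothetical additions (the tutorial hard-codes 3 channels), and the total is counted with `numel()` instead of torch-summary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self, in_channels=3):  # in_channels: hypothetical addition
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 6, 5)  # only this layer depends on in_channels
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


def count_params(model):
    return sum(p.numel() for p in model.parameters())


for c in (3, 1):
    net = Net(in_channels=c)
    conv1_params = sum(p.numel() for p in net.conv1.parameters())
    print(c, conv1_params, count_params(net))
# 3 456 62006
# 1 156 61706
```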

Hope this helps.


Thank you a lot, it was helpful.