Transfer learning, extra input channels

tldr:
I want to transfer trained weights from a conv layer like:
self.conv_init = nn.Conv1d(num_in_channels, num_hidden, 1)

to a new one with more input channels (initializing the new weights with zeros or according to some other initialization scheme), with everything else remaining the same, e.g.:
self.conv_init = nn.Conv1d(num_in_channels+5, num_hidden, 1).

As far as I know, load_state_dict will either ignore that layer or throw an error because of the size mismatch. Is there a built-in way to do this? Or is there another way?
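The load_state_dict behavior described above can be checked with a minimal sketch that uses hypothetical sizes (4 original channels, 5 extra, 16 hidden). It prints the weight shapes and shows that a shape mismatch is raised as a RuntimeError even with strict=False, because strict only covers missing and unexpected keys. It then sketches one possible workaround: zero-padding the checkpoint tensor along the input-channel dimension before loading. In a full model the key would be something like "conv_init.weight" rather than "weight"; that name is an assumption based on the attribute in the question.

```python
import torch
import torch.nn as nn

num_in_channels, num_hidden, extra = 4, 16, 5  # hypothetical sizes

old = nn.Conv1d(num_in_channels, num_hidden, 1)
new = nn.Conv1d(num_in_channels + extra, num_hidden, 1)

# Conv1d weights have shape (out_channels, in_channels, kernel_size),
# so the extra inputs only widen dimension 1.
print(old.weight.shape)  # torch.Size([16, 4, 1])
print(new.weight.shape)  # torch.Size([16, 9, 1])

# A size mismatch is reported as a RuntimeError even with strict=False.
try:
    new.load_state_dict(old.state_dict(), strict=False)
    raised = False
except RuntimeError:
    raised = True
print(raised)  # True

# Possible workaround: pad the checkpoint tensor with zeros for the new
# input channels, then load the edited state_dict as usual.
state = old.state_dict()
w = state["weight"]
state["weight"] = torch.cat(
    [w, w.new_zeros((w.shape[0], extra, w.shape[2]))], dim=1
)
new.load_state_dict(state)
```

Zero-padded input channels contribute nothing to the output, so the widened layer initially reproduces the old layer's output (up to floating-point rounding); another initialization could be substituted for new_zeros.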


story for context:
I’m working on building a net for time series data. The dataset contains sensor readings (let’s call these sensors ‘group a’) going back to the early 2000s. At some point after 2017, extra sensors (‘group b’) were added.
I’d like to train a model on the original group a first (to take advantage of two decades of data), and then use it as a starting point to train another model that also includes group b.
The input is a 1D multichannel tensor: a fixed-size sliding window over the data, where each channel holds one sensor’s readings.
The first conv layer (shown above) has kernel size 1 and creates the hidden channels used by the rest of the layers, which will stay the same between the two models.
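To make the sliding-window input and the kernel-size-1 layer concrete, here is a minimal sketch with hypothetical sizes (4 sensors, 1,000 time steps, windows of 64 steps with a stride of 16, 32 hidden channels). Building the windows with Tensor.unfold is an assumption, not necessarily how the actual data pipeline works. The sketch checks that a Conv1d with kernel_size=1 is the same linear mix of channels applied at every time step.

```python
import torch
import torch.nn as nn

# Hypothetical sizes: 4 sensors, 1000 time steps, windows of 64 steps, stride 16.
num_sensors, length, window, step, num_hidden = 4, 1000, 64, 16, 32

series = torch.randn(num_sensors, length)  # one row (channel) per sensor

# Tensor.unfold(dim, size, step) gives (num_sensors, num_windows, window);
# move the window index to the batch dimension: (num_windows, num_sensors, window).
windows = series.unfold(1, window, step).permute(1, 0, 2).contiguous()

conv_init = nn.Conv1d(num_sensors, num_hidden, kernel_size=1)
hidden = conv_init(windows)  # (num_windows, num_hidden, window)

# With kernel_size=1 the layer computes, at each time step t independently:
# hidden[n, h, t] = b[h] + sum_c W[h, c] * x[n, c, t]
manual = (
    torch.einsum("hc,nct->nht", conv_init.weight[:, :, 0], windows)
    + conv_init.bias[None, :, None]
)
print(torch.allclose(hidden, manual, atol=1e-5))  # True
```

Because each hidden channel is a weighted sum of the sensor channels at the same time step, giving the group b channels zero weights leaves the hidden channels, and everything downstream of them, unchanged at the start of fine-tuning.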

You could load the state_dict before any manipulation of the first conv layer, so that all other layers are properly restored.
Once this is done, you could replace the first conv layer with a new one and reassign the parameters manually. This code snippet gives you an example:

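import torch
import torch.nn as nn

ref = nn.Conv2d(3, 64, 3, 1, 1)   # trained layer with 3 input channels
conv = nn.Conv2d(5, 64, 3, 1, 1)  # new layer with 2 extra input channels

with torch.no_grad():
    conv.weight[:, :3] = ref.weight  # copy the trained filters
    conv.weight[:, 3:] = torch.zeros_like(conv.weight[:, 3:])  # extra channels start at zero
    conv.bias.copy_(ref.bias)  # copy the values instead of sharing the same Parameter

x = torch.randn(2, 3, 24, 24)
out1 = ref(x)
out2 = conv(torch.cat((x, torch.randn(2, 2, 24, 24)), dim=1))

# Compare with a tolerance: the summation order can differ between the two convs.
print(torch.allclose(out1, out2, atol=1e-6))
> True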
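The snippet above uses Conv2d, while the question's layer is a Conv1d inside a larger model. Assuming the model exposes that layer as conv_init (as in the question) and the remaining layers only see the hidden channels, the same two steps (restore the full state_dict into a model of the old shape, then swap and widen the first layer) might look like the following sketch. Net, widen_first_conv, the layer stack, and all sizes are hypothetical, and training is omitted.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical stand-in for the model described in the question."""

    def __init__(self, num_in_channels, num_hidden=32):
        super().__init__()
        self.conv_init = nn.Conv1d(num_in_channels, num_hidden, 1)
        # Later layers only see num_hidden channels, so their shapes
        # do not depend on the number of sensors.
        self.body = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(num_hidden, num_hidden, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(num_hidden, 1),
        )

    def forward(self, x):
        return self.body(self.conv_init(x))


def widen_first_conv(model, extra):
    """Replace model.conv_init with a conv taking `extra` more input channels,
    copying the trained weights and zero-initializing the new ones.
    Assumes default stride/padding/dilation/groups, as in the question."""
    old = model.conv_init
    new = nn.Conv1d(
        old.in_channels + extra, old.out_channels, old.kernel_size,
        bias=old.bias is not None,
    ).to(device=old.weight.device, dtype=old.weight.dtype)
    with torch.no_grad():
        new.weight.zero_()
        new.weight[:, :old.in_channels].copy_(old.weight)
        if old.bias is not None:
            new.bias.copy_(old.bias)
    model.conv_init = new
    return model


num_a, num_b = 4, 5  # hypothetical sensor counts for group a and group b

# Stage 1: a model trained on group a only (training omitted).
model_a = Net(num_a)
checkpoint = model_a.state_dict()

# Stage 2: restore everything into a model of the old shape, then widen.
model_ab = Net(num_a)
model_ab.load_state_dict(checkpoint)
model_ab = widen_first_conv(model_ab, num_b)

x_a = torch.randn(8, num_a, 64)
x_b = torch.randn(8, num_b, 64)
with torch.no_grad():
    same = torch.allclose(
        model_a(x_a), model_ab(torch.cat([x_a, x_b], dim=1)), atol=1e-6
    )
print(same)  # True
```

Since the widened layer is a new module with new Parameter objects, the optimizer for fine-tuning should be created after the swap; an optimizer built earlier would keep pointing at the discarded layer's parameters.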