Change input dimensions of a pre-trained UNet

I am a relative newcomer to DL, and as such, I don’t have a clear grasp of what information is necessary and what isn’t when requesting help from an online community of programmers. So I apologize in advance for the wall of text you’re about to witness!

For my master’s thesis, I’m replicating a paper that uses a UNet to analyze satellite imagery and generate maps showing forest cover in a particular region. The authors used the VGG11 pre-trained model, but this is not the pre-trained model that I have questions about.

Link to the paper, in case anyone’s interested!

Every image from the Landsat-8 satellite has several “bands”, and each one corresponds to a very specific range of wavelengths on the electromagnetic spectrum. There are 4 different configurations for the input data being fed to the model, based on the number of “bands” in the satellite image (18, 11, 7, or 3).

Simply put, “bands” in this context = “input channels” in the context of a neural network.

The authors compare the performance of their novel UNet model with traditional ML classifiers for each input configuration (18 bands, 11 bands, etc.) and conclude that an image consisting of 18 bands provides the best overall performance. Presumably, they saved the best performing UNet model for each configuration as a “.pt” file.

However, in the GitHub README, they’ve only provided the pre-trained UNet model as a “.pt” file for the 18 band input data.

Since I’m replicating their paper, I am also having to test the UNet’s performance on all 4 input configurations. However, since they only provided the 18 band pre-trained model, I cannot test the performance of the other 3 configurations. If I do try to feed an 11 band input to the UNet while using the 18 band pre-trained model as a checkpoint, I get the following error:

Traceback (most recent call last):
  File "train.py", line 84, in <module>
    main(config)
  File "train.py", line 55, in main
    trainer = Trainer(model, criterion, metrics, optimizer,
  File "/scratch/sjayaramannaga/AI-ForestWatch-Srinath/trainer/trainer.py", line 22, in __init__
    super().__init__(model, criterion, metric_ftns, optimizer, config)
  File "/scratch/sjayaramannaga/AI-ForestWatch-Srinath/base/base_trainer.py", line 51, in __init__
    self._resume_checkpoint(cfg_trainer['pretrained_model'])
  File "/scratch/sjayaramannaga/AI-ForestWatch-Srinath/base/base_trainer.py", line 141, in _resume_checkpoint
    self.model.load_state_dict(torch.load(resume_path), strict=False)
  File "/home/sjayaramannaga/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for UNet:
	size mismatch for encoder_1.conv1.weight: copying a param with shape torch.Size([64, 18, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 11, 3, 3]).

When I reached out to one of the authors he suggested that I “load the pretrained weights in the model with 18 channels, change the input dimension of the first layer, and retrain the model”.

How do I do what he suggested? I tried loading the pre-trained model in a Jupyter Notebook using torch.load() and then I printed it. The “.pt” file itself contains only the weights used by the UNet, and nothing more. It is just an OrderedDict that maps parameter names to tensors.

From what I have understood so far, if I have to re-train this 18 band UNet, I have to change encoder_1.conv1 to nn.Conv2d(11, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)).

But I have no idea how to do that with the “.pt” file they’ve provided.

I have provided the code for the model below.

"""
    UNet model definition in here
"""

import torch
import torch.nn as nn
from base import BaseModel
from torch.optim import *
from torchvision import models


class UNet_down_block(BaseModel):
    """
        Encoder class
    """

    def __init__(self, input_channel, output_channel, conv_1=None, conv_2=None):
        super(UNet_down_block, self).__init__()
        if conv_1:
            print('LOG: Using pretrained convolutional layer', conv_1)
        if conv_2:
            print('LOG: Using pretrained convolutional layer', conv_2)
        self.input_channels = input_channel
        self.output_channels = output_channel
        self.conv1 = conv_1 if conv_1 else nn.Conv2d(input_channel, 
                             output_channel, kernel_size=3, padding=1)
        self.conv2 = conv_2 if conv_2 else nn.Conv2d(output_channel, 
                             output_channel, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=output_channel)
        self.bn2 = nn.BatchNorm2d(num_features=output_channel)
        self.activate = nn.ReLU()

    def forward(self, x):
        x = self.activate(self.bn1(self.conv1(x)))
        x = self.activate(self.bn2(self.conv2(x)))
        return x


class UNet_up_block(BaseModel):
    """
        Decoder class
    """

    def __init__(self, prev_channel, input_channel, output_channel):
        super(UNet_up_block, self).__init__()
        self.output_channels = output_channel
        self.tr_conv_1 = nn.ConvTranspose2d(input_channel, input_channel, 
                                               kernel_size=2, stride=2)
        self.conv_1 = nn.Conv2d(prev_channel+input_channel, output_channel, 
                            kernel_size=3, stride=1, padding=1)
        self.conv_2 = nn.Conv2d(output_channel, output_channel, kernel_size=3, 
                           stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=output_channel)
        self.bn2 = nn.BatchNorm2d(num_features=output_channel)
        self.activate = nn.ReLU()

    def forward(self, prev_feature_map, x):
        x = self.tr_conv_1(x)
        x = self.activate(x)
        x = torch.cat((x, prev_feature_map), dim=1)
        x = self.activate(self.bn1(self.conv_1(x)))
        x = self.activate(self.bn2(self.conv_2(x)))
        return x


class UNet(BaseModel):
    def __init__(self, topology, input_channels, num_classes):
        super(UNet, self).__init__()
        # these topologies are possible right now
        self.topologies = {
            "ENC_1_DEC_1": self.ENC_1_DEC_1,
            "ENC_2_DEC_2": self.ENC_2_DEC_2,
            "ENC_3_DEC_3": self.ENC_3_DEC_3,
            "ENC_4_DEC_4": self.ENC_4_DEC_4,
        }
        assert topology in self.topologies
        vgg_trained = models.vgg11(pretrained=True)
        pretrained_layers = list(vgg_trained.features)
        self.max_pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout2d(0.6)
        self.activate = nn.ReLU()
        self.encoder_1 = UNet_down_block(input_channels, 64)
        self.encoder_2 = UNet_down_block(64, 128, conv_1=pretrained_layers[3])
        self.encoder_3 = UNet_down_block(128, 256, conv_1=pretrained_layers[6], 
                            conv_2=pretrained_layers[8])
        self.encoder_4 = UNet_down_block(256, 512, conv_1=pretrained_layers[11], 
                            conv_2=pretrained_layers[13])
        self.mid_conv_64_64_a = nn.Conv2d(64, 64, 3, padding=1)
        self.mid_conv_64_64_b = nn.Conv2d(64, 64, 3, padding=1)
        self.mid_conv_128_128_a = nn.Conv2d(128, 128, 3, padding=1)
        self.mid_conv_128_128_b = nn.Conv2d(128, 128, 3, padding=1)
        self.mid_conv_256_256_a = nn.Conv2d(256, 256, 3, padding=1)
        self.mid_conv_256_256_b = nn.Conv2d(256, 256, 3, padding=1)
        self.mid_conv_512_1024 = nn.Conv2d(512, 1024, 3, padding=1)
        self.mid_conv_1024_1024 = nn.Conv2d(1024, 1024, 3, padding=1)
        self.decoder_4 = UNet_up_block(prev_channel=self.encoder_4.output_channels,
        input_channel=self.mid_conv_1024_1024.out_channels, output_channel=256)
        self.decoder_3 = UNet_up_block(prev_channel=self.encoder_3.output_channels,
        input_channel=self.decoder_4.output_channels, output_channel=128)
        self.decoder_2 = UNet_up_block(prev_channel=self.encoder_2.output_channels,
        input_channel=self.decoder_3.output_channels, output_channel=64)
        self.decoder_1 = UNet_up_block(prev_channel=self.encoder_1.output_channels,
        input_channel=self.decoder_2.output_channels, output_channel=64)
        self.binary_last_conv = nn.Conv2d(64, num_classes, kernel_size=1)
        self.softmax = nn.Softmax(dim=1)
        self.forward = self.topologies[topology]

    def ENC_1_DEC_1(self, x_in):
        x1_cat = self.encoder_1(x_in)
        x1_cat_1 = self.dropout(x1_cat)
        x1 = self.max_pool(x1_cat_1)
        x_mid = self.mid_conv_64_64_a(x1)
        x_mid = self.activate(x_mid)
        x_mid = self.mid_conv_64_64_b(x_mid)
        x_mid = self.activate(x_mid)
        x_mid = self.dropout(x_mid)
        x = self.decoder_1(x1_cat, x_mid)
        x = self.binary_last_conv(x)
        # return the final vector and the corresponding softmax-ed prediction
        return x, self.softmax(x)

    def ENC_2_DEC_2(self, x_in):
        x1_cat = self.encoder_1(x_in)
        x1 = self.max_pool(x1_cat)
        x2_cat = self.encoder_2(x1)
        x2_cat_1 = self.dropout(x2_cat)
        x2 = self.max_pool(x2_cat_1)
        x_mid = self.mid_conv_128_128_a(x2)
        x_mid = self.activate(x_mid)
        x_mid = self.mid_conv_128_128_b(x_mid)
        x_mid = self.activate(x_mid)
        x_mid = self.dropout(x_mid)
        x = self.decoder_2(x2_cat, x_mid)
        x = self.decoder_1(x1_cat, x)
        x = self.binary_last_conv(x)
        # return the final vector and the corresponding softmax-ed prediction
        return x, self.softmax(x)

    def ENC_3_DEC_3(self, x_in):
        x1_cat = self.encoder_1(x_in)
        x1 = self.max_pool(x1_cat)
        x2_cat = self.encoder_2(x1)
        x2_cat_1 = self.dropout(x2_cat)
        x2 = self.max_pool(x2_cat_1)
        x3_cat = self.encoder_3(x2)
        x3 = self.max_pool(x3_cat)
        x_mid = self.mid_conv_256_256_a(x3)
        x_mid = self.activate(x_mid)
        x_mid = self.mid_conv_256_256_b(x_mid)
        x_mid = self.activate(x_mid)
        x_mid = self.dropout(x_mid)
        x = self.decoder_3(x3_cat, x_mid)
        x = self.decoder_2(x2_cat, x)
        x = self.decoder_1(x1_cat, x)
        x = self.binary_last_conv(x)
        # return the final vector and the corresponding softmax-ed prediction
        return x, self.softmax(x)

    def ENC_4_DEC_4(self, x_in):
        x1_cat = self.encoder_1(x_in)
        x1 = self.max_pool(x1_cat)
        x2_cat = self.encoder_2(x1)
        x2_cat_1 = self.dropout(x2_cat)
        x2 = self.max_pool(x2_cat_1)
        x3_cat = self.encoder_3(x2)
        x3 = self.max_pool(x3_cat)
        x4_cat = self.encoder_4(x3)
        x4_cat_1 = self.dropout(x4_cat)
        x4 = self.max_pool(x4_cat_1)
        x_mid = self.mid_conv_512_1024(x4)
        x_mid = self.activate(x_mid)
        x_mid = self.mid_conv_1024_1024(x_mid)
        x_mid = self.activate(x_mid)
        x_mid = self.dropout(x_mid)
        x = self.decoder_4(x4_cat, x_mid)
        x = self.decoder_3(x3_cat, x)
        x = self.decoder_2(x2_cat, x)
        x = self.decoder_1(x1_cat, x)
        x = self.binary_last_conv(x)
        # return the final vector and the corresponding softmax-ed prediction
        return x, self.softmax(x)

Hi sjramen!

Yes, in order to process inputs with different numbers of channels, the most
straightforward thing to do is to modify / replace just the first convolutional
layer in the U-Net with one that accepts your desired number of channels.

I think it will be simplest to do this by modifying the U-Net after loading it
into memory, rather than trying to edit the .pt file.
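
As a minimal sketch of why the checkpoint has to be loaded into an 18-channel model first: strict=False only tolerates missing or unexpected keys, not shape mismatches, which is what the traceback in the question shows. TinyUNet below is a hypothetical stand-in for the real UNet class, and the state dict is built in memory instead of being read from the authors’ .pt file.

```python
import torch
import torch.nn as nn


class TinyUNet(nn.Module):
    """Hypothetical stand-in for the real UNet; only the first conv matters here."""

    def __init__(self, input_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=3, padding=1)
        self.head = nn.Conv2d(64, 2, kernel_size=1)

    def forward(self, x):
        return self.head(torch.relu(self.conv1(x)))


# Pretend this is what torch.load() returned: an OrderedDict of named tensors.
checkpoint = TinyUNet(input_channels=18).state_dict()
print({name: tuple(t.shape) for name, t in checkpoint.items()})

# Loading into an 11-channel model fails even with strict=False.
mismatch_error = None
try:
    TinyUNet(input_channels=11).load_state_dict(checkpoint, strict=False)
except RuntimeError as err:
    mismatch_error = err
    print("load failed:", str(err).splitlines()[0])

# Loading into an 18-channel model works; the first layer can be swapped afterwards.
model = TinyUNet(input_channels=18)
print(model.load_state_dict(checkpoint, strict=True))
```

So the order is: build the model with 18 input channels, load the weights, and only then change the first layer.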

Conceptually this is easy – just replace the first Conv2d layer with one with
the correct number of in_channels. How to do this in practice will depend
on the actual structure of your U-Net – how the layers are packaged, whether
its first layer involves a non-trivial forward() function, and so on.

Let me outline three things: How to replace the first layer; how to initialize the
new first layer with partially-pre-trained weights; and possible approaches to
fine-tuning the modified pre-trained model.

First, to replace the initial layer: Your model will hold a Python reference to
that first layer, either directly as an attribute or inside some sort of Module,
ModuleList, etc. You have to “find” it – that is, figure out how to navigate
through the structure of the model to access that reference. Then you simply
set that reference to a newly-instantiated Conv2d with the correct number of
in_channels.
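
For the model in the question, the error message names encoder_1.conv1.weight, so the reference is presumably model.encoder_1.conv1. Here is a rough sketch of the “find and replace” step. DownBlock and SmallUNet are hypothetical, cut-down stand-ins that only mirror the attribute names of the posted code, not the real architecture.

```python
import torch
import torch.nn as nn


class DownBlock(nn.Module):
    """Hypothetical, simplified stand-in for UNet_down_block."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        return torch.relu(self.bn1(self.conv1(x)))


class SmallUNet(nn.Module):
    """Hypothetical stand-in; only the attribute names mirror the question."""

    def __init__(self, input_channels):
        super().__init__()
        self.encoder_1 = DownBlock(input_channels, 64)
        self.binary_last_conv = nn.Conv2d(64, 2, kernel_size=1)

    def forward(self, x):
        return self.binary_last_conv(self.encoder_1(x))


model = SmallUNet(input_channels=18)  # imagine the 18-band weights were loaded here

# "Find" the first layer: list every Conv2d together with its dotted name.
conv_names = [name for name, m in model.named_modules() if isinstance(m, nn.Conv2d)]
print(conv_names)

# Keep a reference to the old layer, then rebind the attribute to an 11-channel Conv2d
# that copies every other setting from the old one.
old_conv = model.encoder_1.conv1
model.encoder_1.conv1 = nn.Conv2d(
    11,
    old_conv.out_channels,
    kernel_size=old_conv.kernel_size,
    stride=old_conv.stride,
    padding=old_conv.padding,
)

out = model(torch.randn(2, 11, 32, 32))
print(out.shape)
```

Printing the model, or listing named_modules() as above, is usually enough to work out the path to the layer.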

At this point all the other layers in the model will have their pre-trained weights.
The new Conv2d will, however, have randomly-initialized weights. They won’t
match up well with the pre-trained weights.

You might or might not be able to use the pre-trained weights of the original
first Conv2d layer to generate useful weights for the new Conv2d layer. You
can’t just reuse the original weight tensor unchanged because it would be
the wrong shape. If you do well enough in initializing the new weight tensor
with “partially” pre-trained weights, you could plausibly reuse the bias tensor.
(Its shape won’t change.) One possibility would be to initialize each channel
of the new weight tensor identically with the channel-average of the original
weight tensor. Whether, or how much, this will help (relative to random
initialization) would be hard to say without testing, but I think that it’s unlikely
to hurt.
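
The channel-average idea can be wrapped in a small helper. This is only a sketch: the function name conv_with_new_in_channels is made up for illustration, and the 18-to-11 layer sizes are taken from the error message in the question, with a randomly initialized Conv2d standing in for the pre-trained layer.

```python
import torch
import torch.nn as nn


def conv_with_new_in_channels(old_conv, new_in_channels):
    """Build a Conv2d like old_conv but with new_in_channels input channels.

    Every input channel of the new weight is set to the mean over the old
    input channels; the bias (if any) is copied unchanged.
    """
    new_conv = nn.Conv2d(
        new_in_channels,
        old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        bias=old_conv.bias is not None,
    )
    with torch.no_grad():
        mean_w = old_conv.weight.mean(dim=1, keepdim=True)  # shape (out, 1, kH, kW)
        new_conv.weight.copy_(mean_w.expand_as(new_conv.weight))
        if old_conv.bias is not None:
            new_conv.bias.copy_(old_conv.bias)
    return new_conv


old = nn.Conv2d(18, 64, kernel_size=3, padding=1)  # stands in for the pre-trained layer
new = conv_with_new_in_channels(old, 11)
print(new.weight.shape)
```

Every input channel of the new layer starts out identical, so it is the subsequent fine-tuning that lets the channels specialize to the individual bands.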

I use a toy “U-Net” model to illustrate how to perform “surgery” on an existing
model. To make it realistic, the first Conv2d is packaged in a custom Module
as you might see in a real U-Net implementation. We first “find” that first layer,
“navigate” to it, replace it, and then replace the new first layer’s randomly
initialized weights with weights derived from the pre-trained model. Note, you
should perform any such surgery before adding your model’s parameters to
an optimizer.
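
To see why the order matters, here is a small sketch using a throwaway Sequential model (not the U-Net): an optimizer built before the surgery keeps pointing at the old layer’s parameters and never sees the replacement. The helper tracked_ids is just for this illustration.

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(18, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 2, kernel_size=1),
)


def tracked_ids(optimizer):
    """ids of all parameter tensors the optimizer will update."""
    return {id(p) for group in optimizer.param_groups for p in group["params"]}


# Wrong order: optimizer built first, layer replaced afterwards.
early_opt = torch.optim.SGD(model.parameters(), lr=0.01)
model[0] = nn.Conv2d(11, 8, kernel_size=3, padding=1)
early_tracks_new_layer = id(model[0].weight) in tracked_ids(early_opt)

# Right order: optimizer built after the replacement.
late_opt = torch.optim.SGD(model.parameters(), lr=0.01)
late_tracks_new_layer = id(model[0].weight) in tracked_ids(late_opt)

print(early_tracks_new_layer, late_tracks_new_layer)
```

With the wrong order the new first layer would simply stay at its initial values during training.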

The following script demonstrates how to perform such surgery by converting
a three-input-channel model into an eight-input-channel model:

import torch
print (torch.__version__)

_ = torch.manual_seed (2022)

# simplified version of the first layer in a u-net
class ConvLayer (torch.nn.Module):
    def __init__ (self):
        super().__init__()
        self.layer = torch.nn.Sequential (
            torch.nn.Conv2d (3, 16, 3, padding = 'same'),
            torch.nn.ReLU()
        )
    def forward (self, x):
        return  self.layer (x)

# pretend version of a u-net
class Model (torch.nn.Module):
    def __init__ (self):
        super().__init__()
        self.seq = torch.nn.Sequential (
            ConvLayer(),
            torch.nn.Conv2d (16, 32, 3, padding = 'same'),            
            torch.nn.ReLU(),
            torch.nn.Flatten(),
            torch.nn.Linear (32*28*28, 2)   # accept 28x28 images, output two classes
        )
    def forward (self, x):
        return  self.seq (x)

# instantiate Model -- pretend it's a pre-trained u-net
model = Model()
input = torch.randn (2, 3, 28, 28)   # batch of 2 3-channel 28x28 "images"
output = model (input)               # batch of 2 2-class predictions
print ('output:')
print (output)

# attempt to apply original model to 8-channel input -- fails
input8 = torch.randn (2, 8, 28, 28)   # 8-channel input
try:
    output8 = model (input8)
    print ('model (input8) succeeded')
    print ('output8:')
    print (output8)
except RuntimeError:   # channel mismatch raises RuntimeError
    print ('model (input8) failed')

# look at model (to "locate" first layer)
print ('model:')
print (model)

# perform surgery on "pre-trained u-net" model

# keep reference to original Conv2d layer to initialize weights of replacement
conv2d_save = model.seq[0].layer[0]

# replace initial Conv2d layer with new Conv2d that accepts 8-channel input
# the new Conv2d is randomly initialized
model.seq[0].layer[0] = torch.nn.Conv2d (8, 16, 3, padding = 'same')

# apply modified model to 8-channel input -- succeeds
try:
    output8 = model (input8)
    print ('model (input8) succeeded')
    print ('output8:')
    print (output8)
except RuntimeError:
    print ('model (input8) failed')

# initialize new Conv2d layer with channel-average of pre-trained weights
# possibly better than random initialization
with torch.no_grad():
    model.seq[0].layer[0].weight.copy_ (conv2d_save.weight.mean (dim = 1, keepdim = True).expand (16, 8, 3, 3))
    model.seq[0].layer[0].bias.copy_ (conv2d_save.bias)

# 8-channel model with (partially) pre-trained first-layer weights still works
try:
    output8 = model (input8)
    print ('model (input8) succeeded')
    print ('output8:')
    print (output8)
except RuntimeError:
    print ('model (input8) failed')

Here is its output:

1.12.0
output:
tensor([[-0.0413,  0.0124],
        [ 0.0369, -0.0403]], grad_fn=<AddmmBackward0>)
model (input8) failed
model:
Model(
  (seq): Sequential(
    (0): ConvLayer(
      (layer): Sequential(
        (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=same)
        (1): ReLU()
      )
    )
    (1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (2): ReLU()
    (3): Flatten(start_dim=1, end_dim=-1)
    (4): Linear(in_features=25088, out_features=2, bias=True)
  )
)
model (input8) succeeded
output8:
tensor([[-0.0129,  0.0531],
        [-0.0595,  0.0709]], grad_fn=<AddmmBackward0>)
model (input8) succeeded
output8:
tensor([[-0.0301,  0.0707],
        [ 0.0079, -0.0573]], grad_fn=<AddmmBackward0>)

Last, your new model’s weights – at least those of the first few layers – won’t
really match your problem as well as they could. The number of input channels
will be different, but also the meaning of those channels and how they relate to
one another will be different, so you want to improve the pre-trained weights
that you borrowed from the similar, but distinct, pre-trained model.

I’ve never fine-tuned a pre-trained U-Net model. My intuition suggests that
it might be best to first freeze all of the layers except the first, fine-tune the
first layer for a bit (that is, run the forward-backward-optimize loop for some
number of epochs, but only optimize the first layer), then maybe fine-tune
the first few layers for a while, and then finish off by fine-tuning the whole
model.
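
A staged schedule like that could look roughly as follows. This is an untested-on-real-data sketch: the small Sequential model, the random data, the learning rates, and the helper names set_trainable, make_optimizer and train_steps are all placeholders chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Sequential(
    nn.Conv2d(11, 8, kernel_size=3, padding=1),  # plays the role of the replaced first layer
    nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 2, kernel_size=1),
)


def set_trainable(model, prefixes):
    """Enable gradients only for parameters whose name starts with one of prefixes."""
    for name, param in model.named_parameters():
        param.requires_grad_(name.startswith(prefixes))


def make_optimizer(model, lr):
    """Build a fresh optimizer over the currently trainable parameters only."""
    return torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr)


def train_steps(model, optimizer, n_steps=2):
    """Dummy loop on random data; a real loop would iterate over a DataLoader."""
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(n_steps):
        x = torch.randn(4, 11, 16, 16)
        y = torch.randint(0, 2, (4, 16, 16))
        optimizer.zero_grad()
        loss_fn(model(x), y).backward()
        optimizer.step()


first_before = model[0].weight.detach().clone()
last_before = model[4].weight.detach().clone()

# Stage 1: only the first layer is trained.
set_trainable(model, ("0.",))
train_steps(model, make_optimizer(model, lr=1e-3))
stage1_first_changed = not torch.equal(first_before, model[0].weight)
stage1_last_unchanged = torch.equal(last_before, model[4].weight)

# Stage 2: the first two conv layers.
set_trainable(model, ("0.", "2."))
train_steps(model, make_optimizer(model, lr=1e-4))

# Stage 3: the whole model (the empty prefix matches every parameter name).
set_trainable(model, ("",))
train_steps(model, make_optimizer(model, lr=1e-5))

print(stage1_first_changed, stage1_last_unchanged)
```

One caveat for a model with BatchNorm2d layers, such as the one in the question: setting requires_grad to False does not stop BatchNorm running statistics from updating while the model is in train() mode, so “frozen” blocks may still drift slightly unless those layers are put in eval() mode.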

The idea is that the first layer is somewhat out of whack, but the rest of the
model is still quite good. So you want the first layer to adapt itself to the rest
of the model more than you want the rest of the model to adapt itself to the
out-of-whack first layer. But after the first few layers are about right, you do
want to fine-tune the whole model to fit your specific use case.

Best.

K. Frank