Ignore extra channel weights when loading a checkpoint

I need to load a checkpoint from a model with four channels into an otherwise identical model with three channels and ignore, say, the 4th channel.

I can’t retrain the model since the compute cost is prohibitive. It seems like a fairly common issue, e.g. converting RGBA <-> RGB. It would also be handy to know how to use a single channel.

I tried walking over the state_dict:

model.load_state_dict(dict([(n, p) for n, p in checkpoint['state_dict'].items()]), strict=False)

But still get a mismatch error:

size mismatch for model.vision_model.out.2.weight: copying a param with shape torch.Size([4, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([3, 512, 3, 3]).
size mismatch for model.vision_model.out.2.bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([3]).

I obviously need to walk over each channel, but I’m not sure what the best way is. I appreciate any suggestions.
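
Would something along these lines be the right direction? Just a guess on my part; I’m assuming the 4th output channel is the extra one and that I can slice along dim 0 of the tensors named in the error message:

sd = checkpoint['state_dict']
# keep only the first three of the four output channels for the mismatched layer
sd['model.vision_model.out.2.weight'] = sd['model.vision_model.out.2.weight'][:3]
sd['model.vision_model.out.2.bias'] = sd['model.vision_model.out.2.bias'][:3]
model.load_state_dict(sd)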

I have very constrained runtime compute, and of all the optimizations I’ve considered, this would be the biggest improvement: 25% of inference compute is completely wasted right now. That said, keeping the unneeded weights in GPU memory is my main concern.

I’m a bit confused about this statement as I would assume the only difference would be in the inputs having a different number of channels (3 vs. 4).
In this case, only the first conv layer would use a single additional input channel, which adds only a tiny compute and memory overhead compared to every other layer and operation in this model.
E.g. check this small example:

import torch
import torch.nn as nn
from torchvision import models

# default RGB
modelA = models.resnet50()
x = torch.randn(1, 3, 224, 224)
out = modelA(x) # works

# check number of parameters
nb_paramsA = sum([p.nelement() for p in modelA.parameters()])

# change to RGBA inputs
modelB = models.resnet50()
modelB.conv1 = nn.Conv2d(4, 64, 7, 2, 3, bias=False)
x = torch.randn(1, 4, 224, 224)
out = modelB(x) # works

# check number of parameters
nb_paramsB = sum([p.nelement() for p in modelB.parameters()])

# compare
print((nb_paramsB - nb_paramsA) / nb_paramsB)
# 0.00012269089937124044

In any case, you could manipulate the state_dict directly by dropping the unwanted channel before calling load_state_dict.

Wow, thank you for the fascinating example, I’m surprised! I was under the impression that each channel in the model had its own set of weights.

IIRC inference is done on a per-channel basis though, right? If that is the case, then extra channels would still be less than ideal.

Do you know of any examples showing how to drop channels from a state_dict?

The weight attribute of a conv layer has the shape [out_channels, in_channels, height, width]. Since each conv layer defines its own out_channels, only the very first one depends on the actual number of channels in your input tensors.
I.e. look at the first layers of the resnet:

print(modelA)
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )

As you can see, the output activation of the first conv layer will have 64 channels. The next layers thus expect activations with 64 channels and define their own out_channels values (e.g. the next two conv layers still output 64 channels, while the one afterwards increases it to 256).
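
If it helps, you can verify this directly on the weight tensors of the two models from above (the shapes below are just what resnet50 defines):

# weight layout: [out_channels, in_channels, kernel_h, kernel_w]
print(modelA.conv1.weight.shape)            # torch.Size([64, 3, 7, 7])
print(modelB.conv1.weight.shape)            # torch.Size([64, 4, 7, 7])
# deeper layers never see the input channels, so they are identical in both models
print(modelA.layer1[0].conv1.weight.shape)  # torch.Size([64, 64, 1, 1])
print(modelB.layer1[0].conv1.weight.shape)  # torch.Size([64, 64, 1, 1])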

No, that’s not the case. Each layer processes all of its input channels together and outputs activations with its specified number of channels. The number of channels in your input is only important for the very first conv layer.
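
As a quick sanity check with the models from before: both first conv layers produce 64-channel activations, so everything after them is identical (just an illustration, not your actual model):

# both inputs end up as 64-channel activations after the first conv
actA = modelA.conv1(torch.randn(1, 3, 224, 224))
actB = modelB.conv1(torch.randn(1, 4, 224, 224))
print(actA.shape)  # torch.Size([1, 64, 112, 112])
print(actB.shape)  # torch.Size([1, 64, 112, 112])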

Here is one:

# fails
sd = modelB.state_dict()
modelA.load_state_dict(sd)
# RuntimeError: Error(s) in loading state_dict for ResNet:
#  size mismatch for conv1.weight: copying a param with shape torch.Size([64, 4, 7, 7]) from checkpoint, the shape in current model is torch.Size([64, 3, 7, 7]).

print(sd['conv1.weight'].shape)
# torch.Size([64, 4, 7, 7])

# slice the input-channel dim (dim 1), keeping the first 3 channels
sd['conv1.weight'] = sd['conv1.weight'][:, :3, ...]

# works
modelA.load_state_dict(sd)
# <All keys matched successfully>
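
Since you also asked about a single-channel model: the same slicing works there as well, or you could sum the color filters into one channel. Just a sketch; modelC and the summing variant are my suggestions, not something from your setup:

# single-channel variant of the same idea
modelC = models.resnet50()
modelC.conv1 = nn.Conv2d(1, 64, 7, 2, 3, bias=False)

sd1 = modelB.state_dict()
# option 1: keep only the first input channel
sd1['conv1.weight'] = sd1['conv1.weight'][:, :1, ...]
# option 2 (alternative): sum the RGB filters into a single channel instead
# sd1['conv1.weight'] = modelB.state_dict()['conv1.weight'][:, :3].sum(dim=1, keepdim=True)

modelC.load_state_dict(sd1)
# <All keys matched successfully>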