I need to load a checkpoint with four channels into an otherwise identical model with three channels and ignore, say, the 4th channel.

I can't train the model again since the compute is cost-prohibitive. It seems like a fairly common issue, e.g. converting between RGBA and RGB. It would also be handy to know how to do the same for a single channel.

I tried walking over the `state_dict`:

```python
model.load_state_dict(dict([(n, p) for n, p in checkpoint['state_dict'].items()]), strict=False)
```

But I still get a shape mismatch error:

```
size mismatch for model.vision_model.out.2.weight: copying a param with shape torch.Size([4, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([3, 512, 3, 3]).
size mismatch for model.vision_model.out.2.bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([3]).
```

I obviously need to walk each mismatched tensor and drop the extra channel, but I'm not sure what the best way is. I'd appreciate any suggestions.
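Something like this is what I have in mind, a rough untested sketch (the helper name `adapt_state_dict` is just mine): slice each checkpoint tensor down to the shape the model expects, keeping the leading entries along every mismatched dimension, then load with `strict=False`.

```python
import torch


def adapt_state_dict(ckpt_state, model):
    """Slice checkpoint tensors whose shapes exceed the model's,
    keeping the leading entries along each mismatched dimension
    (so a [4, 512, 3, 3] RGBA conv weight becomes [3, 512, 3, 3]
    by dropping the 4th output channel)."""
    model_state = model.state_dict()
    adapted = {}
    for name, param in ckpt_state.items():
        if name in model_state and model_state[name].shape != param.shape:
            # Build a slice(0, d) for each dim of the model's tensor.
            slices = tuple(slice(0, d) for d in model_state[name].shape)
            param = param[slices]
        adapted[name] = param
    return adapted


# usage, matching the error message above:
# model.load_state_dict(adapt_state_dict(checkpoint['state_dict'], model),
#                       strict=False)
```

For the single-channel case the same slicing would keep only the R weights; averaging the RGB weights along the channel dimension might be a better initialization there, but that's a separate question.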

My runtime compute is very constrained, and of all the optimizations I've considered this would be the biggest win: 25% of inference compute is completely wasted right now. Keeping the unneeded weights in GPU memory is my main concern, though.