I am modifying pre-trained models so that they accept 2 input images. I tested my code with the vgg11_bn, resnet50, and vit_b_16 nets, and all of them worked well. However, when applying the same approach to inception_v3, I get the following error:
RuntimeError: Given groups=1, weight of size [32, 6, 3, 3], expected input[1, 3, 299, 299] to have 6 channels, but got 3 channels instead
To replicate the problem:
import torch
from torchvision import models
model = models.inception_v3(pretrained=True)
first_layer = model.Conv2d_1a_3x3.conv
weight = first_layer.weight.clone()
# Modify the input layer to take 2 images as input (i.e., 6 channels instead of 3)
first_layer = torch.nn.Conv2d(6, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
# Copy weights channelwise:
with torch.no_grad():
    first_layer.weight[:, :3] = weight
    first_layer.weight[:, 3:] = weight
model.Conv2d_1a_3x3.conv = first_layer
inputs = torch.randn(2, 6, 299, 299) # using batch size 2 to avoid problems with batch norm layers
model(inputs)
Indeed, it needs a batch size greater than 1. I have edited the question accordingly, thank you for the heads-up!
Anyway, the problem I’m having is not due to the batch size. I also tested this code on another machine and got the same error.
Could you please try replicating my package versions to see whether the problem occurs for you as well?
I have a requirements.txt file with the following content:
matplotlib==3.5.1
numpy==1.21.5
pandas==1.4.2
Pillow==9.2.0
torch==1.12.1
torchmetrics==0.10.0
torchvision==0.13.1
tqdm==4.64.0
As you noticed, I'm using torchvision==0.13.1. The previous code prints a deprecation warning; with this version of torchvision, the proper way to load the model is now:
weights = models.Inception_V3_Weights.IMAGENET1K_V1
model = models.inception_v3(weights=weights)
However, it doesn't fix the problem. The same error is still raised:
RuntimeError: Given groups=1, weight of size [32, 6, 3, 3], expected input[1, 3, 299, 299] to have 6 channels, but got 3 channels instead
You are right and I am able to reproduce the issue when using a pretrained model (my previous code snippet used a randomly initialized model).
The issue is caused by Inception's built-in input transformation, which is enabled by default when pretrained weights are loaded: it re-normalizes channels 0-2 and concatenates only those three, silently discarding any extra channels before the first convolution.
Use model = models.inception_v3(pretrained=True, transform_input=False) (or, with the new API, model = models.inception_v3(weights=weights, transform_input=False)) and it should work.