Why can't torchvision.models forward inputs larger than 430*430?

I am passing inputs of different sizes through torchvision.models.resnet18.
It works when the size is smaller than 400, but when I change the size to 430*430 it raises an error:
'RuntimeError: size mismatch, m1: [1 x 2048], m2: [512 x 1000] at /b/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:1229'
How can I use a larger input with torchvision.models?

Can you try different input sizes (e.g. 410, 412, 420, 426, …)? Sometimes only multiples of a specific resolution are compatible; otherwise there are mismatches between layers, for example because of padding.
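
If it helps, you can probe which sizes go through with a quick loop. A minimal sketch (on the PyTorch version in this thread you would wrap the tensor in a Variable first; also note that newer torchvision releases already use adaptive pooling, so every size may pass there):

    import torch
    from torchvision.models import resnet18

    model = resnet18(pretrained=False)
    for s in (400, 410, 415, 416, 418, 430):
        try:
            model(torch.randn(1, 3, s, s))  # dummy batch of one s x s image
            print(s, 'ok')
        except RuntimeError as e:
            print(s, 'fails:', e)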

Hi @bodokaiser,
Yes, I have tried many input sizes. When the size is smaller than 416 it produces an output, but when it is larger than 415 it raises 'RuntimeError: size mismatch, m1: [1 x 2048], m2: [512 x 1000] at /b/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:1229'.
When I tried a much larger size like 1000, the error changed to 'RuntimeError: size mismatch, m1: [1 x 8192], m2: [512 x 1000] at /b/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:1229'.
The same thing happens with other models like resnet34.

Can you post the complete error stack trace?

The problem is probably in the fully-connected layer: it expects a feature map of a specific size, and yours is now bigger.
You can bypass that by adding an Adaptive{Max,Avg}Pool2d layer just before the classifier, so that the output size is always the same and the classifier can work.
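
The key property of adaptive pooling is that you pick the output size and the kernel and stride are derived from the input, so the classifier always sees the same number of features. A minimal sketch:

    import torch
    import torch.nn as nn

    pool = nn.AdaptiveAvgPool2d(1)  # always pools down to 1 x 1 spatially
    for s in (7, 13, 14, 32):
        x = torch.randn(1, 512, s, s)
        print(s, pool(x).size())  # torch.Size([1, 512, 1, 1]) every time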

@bodokaiser,
My code is:

import os
import torch
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.models import resnet18

# named "transform" to avoid shadowing the transforms module
transform = transforms.Compose([
    transforms.CenterCrop(418),
    transforms.ToTensor(),
])
dset = datasets.ImageFolder(os.path.join('./data/'), transform)
dset_loader = torch.utils.data.DataLoader(dset, batch_size=1)
inputs, classes = next(iter(dset_loader))

model_resnet18 = resnet18(pretrained=False)
input_test = Variable(inputs)
print(input_test.size(), input_test)
out = model_resnet18(input_test)

The output of input_test.size() is torch.Size([1, 3, 418, 418]).

And the complete error stack trace is:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-4-b764430b1113> in <module>()
      2 input_test = Variable(torch.Tensor(input_np))
      3 print(input_test.size(), input_test)
----> 4 out = model_resnet18(input_test)
      5 print(out)

/usr/share/Anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.pyc in __call__(self, *input, **kwargs)
    204 
    205     def __call__(self, *input, **kwargs):
--> 206         result = self.forward(*input, **kwargs)
    207         for hook in self._forward_hooks.values():
    208             hook_result = hook(self, input, result)

/usr/share/Anaconda2/lib/python2.7/site-packages/torchvision/models/resnet.pyc in forward(self, x)
    148         x = self.avgpool(x)
    149         x = x.view(x.size(0), -1)
--> 150         x = self.fc(x)
    151 
    152         return x

/usr/share/Anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.pyc in __call__(self, *input, **kwargs)
    204 
    205     def __call__(self, *input, **kwargs):
--> 206         result = self.forward(*input, **kwargs)
    207         for hook in self._forward_hooks.values():
    208             hook_result = hook(self, input, result)

/usr/share/Anaconda2/lib/python2.7/site-packages/torch/nn/modules/linear.pyc in forward(self, input)
     52             return self._backend.Linear()(input, self.weight)
     53         else:
---> 54             return self._backend.Linear()(input, self.weight, self.bias)
     55 
     56     def __repr__(self):

/usr/share/Anaconda2/lib/python2.7/site-packages/torch/nn/_functions/linear.pyc in forward(self, input, weight, bias)
      8         self.save_for_backward(input, weight, bias)
      9         output = input.new(input.size(0), weight.size(0))
---> 10         output.addmm_(0, 1, input, weight.t())
     11         if bias is not None:
     12             # cuBLAS doesn't support 0 strides in sger, so we can't use expand

RuntimeError: size mismatch, m1: [1 x 2048], m2: [512 x 1000] at /b/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:1229

@bodokaiser, @fmassa thanks a lot!
I found the problem.
It is because the kernel size of the avgpool layer in resnet is fixed in the code:

self.avgpool = nn.AvgPool2d(7)
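
For anyone counting: a 418*418 input comes out of layer4 as a 14x14 map instead of the 7x7 that a 224*224 input produces, and AvgPool2d(7) (whose stride defaults to the kernel size) turns 14x14 into 2x2, so flattening gives 2 * 2 * 512 = 2048 features against the fc weight of [512 x 1000], which is exactly the m1: [1 x 2048] in the error. A 1000*1000 input gives a 32x32 map, pooled to 4x4, hence 4 * 4 * 512 = 8192.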

so I think I should replace it with nn.AdaptiveAvgPool2d(1).

I think I could also redefine the avgpool in the forward function. Is it feasible to redefine it during forward, and why does most code define all layers when initializing the module? Is there another way to define it in the init so that the avgpool kernel size can change?
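
Concretely, something like this is what I mean by doing it in forward. A minimal sketch (the wrapper class name is just for illustration; since pooling has no learnable parameters, calling the functional version in forward should be safe):

    import torch
    import torch.nn.functional as F
    from torchvision.models import resnet18

    class ResNet18AnySize(torch.nn.Module):  # hypothetical wrapper
        def __init__(self):
            super(ResNet18AnySize, self).__init__()
            self.base = resnet18(pretrained=False)

        def forward(self, x):
            # reuse the stock stem and residual stages
            x = self.base.conv1(x)
            x = self.base.bn1(x)
            x = self.base.relu(x)
            x = self.base.maxpool(x)
            x = self.base.layer1(x)
            x = self.base.layer2(x)
            x = self.base.layer3(x)
            x = self.base.layer4(x)
            # parameter-free, so it can live here instead of __init__;
            # pools any H x W down to 1 x 1
            x = F.adaptive_avg_pool2d(x, 1)
            x = x.view(x.size(0), -1)
            return self.base.fc(x)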


Thanks, you are right!

Facing the same problem here. Just wondering, how did you modify the original resnet to use the adaptive avgpool operation?

When do you want to do that?

As @XavierLinNow suggested, you just need to use:

    self.pool2 = nn.AdaptiveAvgPool2d(1)

after all your conv layers, so in the end you will always get a fixed-size output.
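
If you are using the stock torchvision model, you can also just swap the attribute instead of writing a new class. A minimal sketch (assuming a torchvision version where avgpool is a plain module attribute, as in the code quoted above):

    import torch
    import torch.nn as nn
    from torchvision.models import resnet18

    model = resnet18(pretrained=False)
    model.avgpool = nn.AdaptiveAvgPool2d(1)  # replaces the fixed AvgPool2d(7)

    out = model(torch.randn(1, 3, 418, 418))  # no more size mismatch
    print(out.size())  # torch.Size([1, 1000])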
