Error after adding layer to downloaded model

Hi,
I am using the DeepLabV3_ResNet50 model for a project and want to build a classifier on top of this network. I tried to add a Flatten layer at the end of the network with model2.classifier.add_module('fl', torch.nn.Flatten()), so that I can add some fully connected layers later. However, after adding this layer, passing a dummy input input = torch.ones(1, 3, 112, 112) raises the following error:

Traceback (most recent call last):
  File "/usr/lib/python3.8/code.py", line 90, in runcode
    exec(code, self.locals)
  File "<input>", line 1, in <module>
  File "/home/spbtu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/spbtu/.local/lib/python3.8/site-packages/torchvision/models/segmentation/_utils.py", line 25, in forward
    x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
  File "/home/spbtu/.local/lib/python3.8/site-packages/torch/nn/functional.py", line 3079, in interpolate
    raise ValueError('size shape must match input shape. '
ValueError: size shape must match input shape. Input is 0D, size is 2

Is something wrong with the way I add the flattening layer?

nn.Flatten uses the default arguments:

start_dim=1, end_dim=-1

so the batch dimension (dim0) is kept while all remaining dimensions are flattened into a single one.
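A quick shape check shows what this does to a segmentation-style output. The sketch assumes a classifier output of shape [1, 21, 14, 14]: 21 is torchvision's default number of classes for DeepLabV3, while the 14x14 spatial size is only an illustrative assumption.

```python
import torch
import torch.nn as nn

# Stand-in for the classifier output: [batch_size, channels, height, width]
# (the exact sizes are illustrative assumptions)
logits = torch.ones(1, 21, 14, 14)

# Default nn.Flatten (start_dim=1, end_dim=-1): everything except dim0 is merged
flat = nn.Flatten()(logits)
print(flat.shape)  # torch.Size([1, 4116])

# Flattening only the spatial dims keeps the channel dimension separate
flat_spatial = nn.Flatten(start_dim=2)(logits)
print(flat_spatial.shape)  # torch.Size([1, 21, 196])
```

Only the batch dimension survives the default call, so any later operation that expects an [N, C, H, W] tensor can no longer work on the result.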

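Based on the error message, this is what happens here: the forward method of torchvision's segmentation models calls F.interpolate on the classifier output to upsample it back to the input resolution (line 25 of _utils.py in your traceback). After your change, this 2D interpolation is applied to the flattened tensor, which won't work: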

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 4, 4)
out = F.interpolate(x, size=(10, 10)) # works: 4D input with 2 spatial dims

y = x.view(x.size(0), -1) # flattened to shape [1, 48]
out = F.interpolate(y, size=(10, 10)) # error
> ValueError: size shape must match input shape. Input is 0D, size is 2

out = F.interpolate(torch.nn.Flatten()(x), size=(10, 10)) # error, same as above
> ValueError: size shape must match input shape. Input is 0D, size is 2

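F.interpolate needs a 3D, 4D, or 5D input tensor ([batch_size, channels, seq_len], [batch_size, channels, height, width], or [batch_size, channels, depth, height, width], respectively), and the number of entries in size must match the number of spatial dimensions. The "Input is 0D" in the message refers to those spatial dimensions: a flattened [batch_size, features] tensor has none left.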
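If the goal is an image-level classifier, one option (a sketch, not the only approach) is to stop appending layers to model.classifier, since the model's forward always upsamples the classifier output with F.interpolate, and instead feed the backbone features into your own head. The head below applies global average pooling before nn.Flatten, so the spatial dimensions are removed before the fully connected layers. The class name ClassificationHead, the hidden size of 256, and num_classes=10 are hypothetical choices for illustration; in_channels=2048 matches the channel count of ResNet-50's last stage. A random tensor stands in for the backbone output.

```python
import torch
import torch.nn as nn


class ClassificationHead(nn.Module):
    """Hypothetical image-level head: pool -> flatten -> fully connected layers."""

    def __init__(self, in_channels: int = 2048, num_classes: int = 10):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)  # [N, C, H, W] -> [N, C, 1, 1]
        self.flatten = nn.Flatten()          # [N, C, 1, 1] -> [N, C]
        self.fc = nn.Sequential(
            nn.Linear(in_channels, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes),
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.fc(self.flatten(self.pool(features)))


# Random stand-in for backbone features (the 14x14 spatial size is assumed)
features = torch.randn(2, 2048, 14, 14)
head = ClassificationHead()
logits = head(features)
print(logits.shape)  # torch.Size([2, 10])
```

With the torchvision model, the backbone returns a dict, so the features would come from something like model2.backbone(input)['out']. Calling the backbone directly bypasses the segmentation forward, and with it the interpolation step that raised the error.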