RuntimeError: expected stride to be a single integer value or a list of 1 values to match the convolution dimensions, but got stride=[1, 1] with self-assigned weights

I went through some of the threads and was able to get rid of some issues, but I am still facing this error:

RuntimeError: expected stride to be a single integer value or a list of 1 values to match the convolution dimensions, but got stride=[1, 1]

In my case, I have assigned my own weights, and I am using a third-party library (torchinfo) just to see how the network works. I am passing an input of shape (1, 1, 28, 28). My model is as follows:

net(
  (conv1): Conv2d(1, 1, kernel_size=(2, 2), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(1, 1, kernel_size=(2, 2), stride=(1, 1))
)

Any suggestions on getting rid of the error?

Could you post an executable code snippet using random inputs, which would reproduce this issue, please?

--------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
~/anaconda3/envs/pytorch/lib/python3.8/site-packages/torchinfo/torchinfo.py in summary(model, input_size, input_data, batch_dim, col_names, col_width, depth, device, dtypes, verbose, **kwargs)
    159                 if isinstance(x, (list, tuple)):
--> 160                     _ = model.to(device)(*x, **kwargs)
    161                 elif isinstance(x, dict):

~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(

<ipython-input-94-b91d7160d82c> in forward(self, x)
     11     def forward(self, x):
---> 12         x = self.pool(F.relu(self.conv1(x)))
     13 #         x = self.pool(F.relu(self.conv2(x)))

~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(

~/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py in forward(self, input)
    422     def forward(self, input: Tensor) -> Tensor:
--> 423         return self._conv_forward(input, self.weight)
    424 

~/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight)
    418                             _pair(0), self.dilation, self.groups)
--> 419         return F.conv2d(input, weight, self.bias, self.stride,
    420                         self.padding, self.dilation, self.groups)

RuntimeError: expected stride to be a single integer value or a list of 1 values to match the convolution dimensions, but got stride=[1, 1]

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-97-f52bda561c0c> in <module>
----> 1 summary(net,(1,1,28,28))

~/anaconda3/envs/pytorch/lib/python3.8/site-packages/torchinfo/torchinfo.py in summary(model, input_size, input_data, batch_dim, col_names, col_width, depth, device, dtypes, verbose, **kwargs)
    167         except Exception as e:
    168             executed_layers = [layer for layer in summary_list if layer.executed]
--> 169             raise RuntimeError(
    170                 "Failed to run torchinfo. See above stack traces for more details. "
    171                 f"Executed layers up to: {executed_layers}"

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: []

That’s unfortunately not a code snippet, but just the error message with the complete stack trace.

My sincerest apologies for the goof-up; I don't know what I was thinking. This is how I initialised the weights of my convolution layer:

conv1 = nn.Conv2d(1, 1, kernel_size=2)
with torch.no_grad():
  conv1.weight.data = torch.tensor([[-0.8423,  0.3778],[-3.1070, -2.6518]]) 

And then, I just defined the model like this:

class prune_net(nn.Module):
    def __init__(self):
        super(prune_net, self).__init__()
        self.conv1 = conv1
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        return x

Thanks for the code snippet.
The simple model runs fine without the modification of conv1.weight and doesn't raise any errors.
When you run the conv1.weight.data assignment, you should get an error like:

TypeError: cannot assign 'torch.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected)
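A minimal sketch, assuming a recent PyTorch install, of how this TypeError arises: nn.Module refuses to replace a registered parameter such as `weight` with a plain tensor. The exact wording of the message may differ between versions.

```python
import torch
import torch.nn as nn

conv1 = nn.Conv2d(1, 1, kernel_size=2)

error_message = None
try:
    # Assigning a plain tensor (not an nn.Parameter) to a registered parameter
    conv1.weight = torch.tensor([[-0.8423, 0.3778], [-3.1070, -2.6518]])
except TypeError as e:
    error_message = str(e)

print(error_message)
# The original parameter is left untouched
print(type(conv1.weight), conv1.weight.shape)
```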

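After fixing it, you would run into another error, since the weight shape does not match the one an nn.Conv2d layer expects (4 dimensions: [out_channels, in_channels, kernel_height, kernel_width]), so you would need to unsqueeze two dimensions.
Also, you shouldn't use the .data attribute, as changes made through it are not tracked by autograd and might cause subtle issues.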
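As a sketch of an alternative to both .data and re-wrapping in a new nn.Parameter (an assumption on our side, not the approach used in this thread), the custom 2x2 kernel can be unsqueezed to the expected 4-D shape and copied in place under torch.no_grad(), so conv1.weight stays the same registered parameter:

```python
import torch
import torch.nn as nn

conv1 = nn.Conv2d(1, 1, kernel_size=2)
# Expected layout: (out_channels, in_channels, kH, kW)
print(conv1.weight.shape)  # torch.Size([1, 1, 2, 2])

custom = torch.tensor([[-0.8423, 0.3778], [-3.1070, -2.6518]])  # shape (2, 2)

with torch.no_grad():
    # Add the two leading dims -> (1, 1, 2, 2), then copy in place
    conv1.weight.copy_(custom.unsqueeze(0).unsqueeze(0))

x = torch.randn(1, 1, 28, 28)
out = conv1(x)
print(out.shape)  # torch.Size([1, 1, 27, 27])
```

copy_ would also broadcast a (2, 2) source into the (1, 1, 2, 2) parameter, but the explicit unsqueeze calls make the intended layout obvious.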

Here is the fixed code, which doesn’t raise the initial error, so I guess some parts might still be missing in order to reproduce this issue:

import torch
import torch.nn as nn
import torch.nn.functional as F

conv1 = nn.Conv2d(1, 1, kernel_size=2)
with torch.no_grad():
    conv1.weight = nn.Parameter(torch.tensor([[[[-0.8423,  0.3778],[-3.1070, -2.6518]]]]))

class prune_net(nn.Module):
    def __init__(self):
        super(prune_net, self).__init__()
        self.conv1 = conv1
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        return x

model = prune_net()
x = torch.randn(1, 1, 24, 24)
out = model(x)
print(out.shape)
> torch.Size([1, 1, 11, 11])
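The 11x11 output follows from the usual output-size formula for convolution and pooling layers (ceil_mode=False): out = floor((n + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1. A small sketch with a hypothetical helper `out_size` checks the numbers (ReLU does not change the shape):

```python
import torch
import torch.nn as nn

def out_size(n, kernel, stride=1, padding=0, dilation=1):
    # Hypothetical helper: output-size formula for Conv2d / MaxPool2d (ceil_mode=False)
    return (n + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

after_conv = out_size(24, kernel=2)                    # 23
after_pool = out_size(after_conv, kernel=2, stride=2)  # 11
print(after_conv, after_pool)

# Cross-check against the actual layers
x = torch.randn(1, 1, 24, 24)
y = nn.MaxPool2d(2, 2)(nn.Conv2d(1, 1, kernel_size=2)(x))
print(y.shape)  # torch.Size([1, 1, 11, 11])
```

With the (1, 1, 28, 28) input from the original question, the same formula gives 27 after the convolution and 13 after pooling.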

Could this be a version issue? I am running 1.7. Which one are you running?

Bingo. It's working now. So what's the issue with my code? If I initialize the weights the way I did, do I have to convert them to an nn.Parameter for it to run?

Yes, you have to assign an nn.Parameter to the .weight attribute.
However, I don’t know what exactly caused your issue, as I was getting an error while executing your code (as described before).
I was using the latest nightly release and haven’t tried 1.7.