RuntimeError: expected stride to be a single integer value or a list of 3 values to match the convolution dimensions, but got stride=[1, 1], but dimensions fits

Hey all,
I got following error:

RuntimeError: expected stride to be a single integer value or a list of 3 values to match the convolution dimensions, but got stride=[1, 1]

all I’ve found is to use unqueeze, but I already did. I’m just wondering because I my training setup everything works fine.

Train:

# Write a function that determines a threshold for input/output neurons to be set to zero
# (the ones which have been reduced in absolute value using the sparsity constraint).
# You can use the function topk (similar to nth_element in C++),
# which outputs both the values and indices sorted around a chosen quantile/percentile.
# Here we simply use the median to set 50% of values to zero.
# When applied correctly (as incoming & outgoing mask) for each Conv2d layer, 
# it reduces the nonzero parameters by ~75% (the first incoming & last outgoing Conv2d are not masked).
# Note that BatchNorm has four tensors and two index masks have to be applied as follows: 
#     B = A[idx_next,:,:,:][:,idx_prev,:,:]
# Now you can replace all Conv2d and BatchNorm2d layers with smaller filters 
# (and copy their weights) so that we have the following sequence of channels:
# 3, 32, 64, (2x)128, (3x)256, 512  
def flop50k(net_trained,percentOut):
    percentIn=1-percentOut
    for feature_idx in range(29):
        if isinstance(net_trained.features[feature_idx],(nn.BatchNorm2d)):
            if feature_idx != 26:
                sizeIn = int(net_trained.features[feature_idx].bias.size(0)*percentIn)
                #layerBias = net_trained.features[feature_idx].bias
                layerWeight = net_trained.features[feature_idx].weight
                #bias_flopk = torch.topk(layerBias,sizeIn, largest=True)[1]
                weight_flopk = torch.topk(layerWeight,sizeIn, largest=True)[1].sort()[0]
                #net_trained.features[feature_idx].weight[weight_flopk] = 0
                #net_trained.features[feature_idx].bias[bias_flopk] = 0
                #net_trained.features[feature_idx-1].weight[bias_flopk,:,:,:] = 0
                #net_trained.features[feature_idx-1].bias[bias_flopk,:,:,:] = 0
                #save params of BatchNorm where the weights are above the treshold
                #totalparams = net_trained.features[feature_idx].total_params
                weightBatch = net_trained.features[feature_idx].weight[weight_flopk]
                biasBatch = net_trained.features[feature_idx].bias[weight_flopk]
                runnngmeanBatch = net_trained.features[feature_idx].running_mean[weight_flopk]
                runningvarBatch = net_trained.features[feature_idx].running_var[weight_flopk]
                #creating slimmer BatchNorm
                net_trained.features[feature_idx] = nn.BatchNorm2d(int(sizeIn), 
                                                           eps=1e-05, momentum=0.1,
                                                           affine=True, 
                                                           track_running_stats=True)
                net_trained.features[feature_idx].weight.data = weightBatch
                net_trained.features[feature_idx].bias.data = biasBatch
                net_trained.features[feature_idx].running_mean.data = runnngmeanBatch
                net_trained.features[feature_idx].running_var.data = runningvarBatch
            #adapt the joint Cov2d
            weightConv = net_trained.features[feature_idx-1].weight
            biasConv = net_trained.features[feature_idx-1].bias
            # size(0) outchannel size(1) inchannel
            sizeConv = weightConv.size()
            #create new Conv layer
            #don't reduce the first input channel size
            if feature_idx == 1:
                net_trained.features[feature_idx-1] = nn.Conv2d(sizeConv[1], 
                                                        int(sizeConv[0]*percentIn), 
                                                        kernel_size=3, 
                                                        stride=1, 
                                                        padding=1)
                net_trained.features[feature_idx-1].weight.data = weightConv[weight_flopk,:,:,:]
                net_trained.features[feature_idx-1].bias.data = biasConv[weight_flopk]
            elif feature_idx == 26:
                #Batch und conv auf 512
                net_trained.features[feature_idx-1] = nn.Conv2d(int(sizeConv[1]*percentIn), 
                                                        512, 
                                                        kernel_size=3, 
                                                        stride=1, 
                                                        padding=1)
                net_trained.features[feature_idx-1].weight.data = weightConv[:,weight_flopk_prev.view(1,-1),:,:]
                net_trained.features[feature_idx-1].bias.data = biasConv[weight_flopk]
            else:
                net_trained.features[feature_idx-1] = nn.Conv2d(int(sizeConv[1]*percentIn), 
                                                        int(sizeConv[0]*percentIn), 
                                                        kernel_size=3, 
                                                        stride=1, 
                                                        padding=1)
                net_trained.features[feature_idx-1].weight.data = weightConv[weight_flopk.view(-1,1),weight_flopk_prev.view(1,-1),:,:]
                net_trained.features[feature_idx-1].bias.data = biasConv[weight_flopk]
            #save the index for the inchannels of the next layer
            weight_flopk_prev = weight_flopk
    return net_trained
#net2.eval()
net3=copy.deepcopy(net2)
with torch.no_grad():
    # SAVE THE TRAINED MODEL
    net4 = flop50k(net3,0.5)
    print(net3)
    print(net4)
    torch.save(net4.cpu(),'mdl_4_net4.pth')

Test:
# Evaluate the slimmed network (you could observe a slight improvement to ~92%)
# and confirm that the required computations are reduced to 12 GFlops
with torch.no_grad():
    net4 = torch.load('mdl_4_net4.pth')
    net4.eval()
    test_num = img_test.size(0)
    test_acc = 0.0
    # LOOP over the testdata set
    for i in range(test_num):
        test_img = img_test[i,:,:,:].unsqueeze(0)
        print(test_img.size())
        test_label =label_test[0,i].cuda()
        print(test_label.size())
        output = net4(test_img).cuda()
        accuracy = torch.mean((output.argmax(1)==test_label).float()).cuda()
        test_acc += accuracy
    test_acc /= test_num
    print('Test accuracy:', test_acc)

Which line of code throws this error?
Could you post the definition of this layer once you’ve found it and if possible the input shapes?
If you can’t locate the source of this error, could you instead post a code snippet to reproduce this issue, so that we can have a look?

PS: I’ve formatted your code for better readability. You can add code snippets by wrapping them in three backticks ``` :wink:

Thank you for the hint :wink:
My net is defined as followed:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace)
    (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace)
    (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (15): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (16): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (17): ReLU(inplace)
    (18): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (19): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (20): ReLU(inplace)
    (21): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (22): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (23): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (24): ReLU(inplace)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (27): ReLU(inplace)
    (28): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (classifier): Sequential(
    (0): Linear(in_features=512, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=2, bias=True)
  )
)

which works. For compression I want to reduce the size of feature channels by 50%. Afterwards the net looks like this.

VGG(
  (features): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace)
    (11): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace)
    (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (15): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (16): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (17): ReLU(inplace)
    (18): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (19): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (20): ReLU(inplace)
    (21): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (22): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (23): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (24): ReLU(inplace)
    (25): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (27): ReLU(inplace)
    (28): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (classifier): Sequential(
    (0): Linear(in_features=512, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=2, bias=True)
  )
)

The error is thrown by calling the sliced net in the test mode:
output = net4(test_img).cuda()

RuntimeError                              Traceback (most recent call last)
<ipython-input-32-f32bcd27b8a3> in <module>()
     10         test_label =label_test[0,i].cuda()
     11         print(test_label.size())
---> 12         output = net4(test_img).cuda()
     13         accuracy = torch.mean((output.argmax(1)==test_label).float()).cuda()
     14         test_acc += accuracy

5 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in forward(self, input)
    336                             _pair(0), self.dilation, self.groups)
    337         return F.conv2d(input, self.weight, self.bias, self.stride,
--> 338                         self.padding, self.dilation, self.groups)
    339 
    340 

RuntimeError: expected stride to be a single integer value or a list of 3 values to match the convolution dimensions, but got stride=[1, 1]

img_test:
torch.Size([16384, 3, 48, 48])
img_train
torch.Size([65536, 3, 48, 48])

Both model definitions seem to be identical.
What did you change in the second model?

Oops sry, I changed it.

Okay I found the mistake.
net_trained.features[feature_idx-1].weight.data = weightConv[weight_flopk.view(-1,1),weight_flopk_prev.view(1,-1),:,:]
should slice my net. For example [128,64,3,3] to [64,32,3,3] with weight_flopk containing 64 indices and weight_flopk_prev containing 32, but gut a tensor of [64,64,3,3]. How would I select the dimensions correctly?

The additional dimension in weight_flopk_prev might cause some trouble.
Here is a small example:

weight = torch.randn(10, 6, 3, 3)
idx0 = torch.randint(0, 10, (5,))
idx1 = torch.randint(0, 6, (3,))

new_weight = weight[idx0.unsqueeze(1), idx1]
print(new_weight.shape)
> torch.Size([5, 3, 3, 3])

Also, if you are trying to manipulate the parameters directly, warp it in a with torch.no_grad() block and don’t use .data as this might yield to wrong calculations.