Possible strange bug in F.conv2d

When playing around with F.conv2d, I found some unexpected results.
Basically, given a set of filter weights of size (3, 1, 4, 4), the convolution takes a 1-channel input and produces a 3-channel output, and the 3 filters in the weights are applied independently of each other. Therefore, with no bias and with the same stride and padding settings,
doing:

F.conv2d(input, weight, **kwargs)

should give the same result, channel by channel, as doing:

F.conv2d(input, weight[0].unsqueeze(0), **kwargs)
F.conv2d(input, weight[1].unsqueeze(0), **kwargs)
F.conv2d(input, weight[2].unsqueeze(0), **kwargs)
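For concreteness, here is the shape relationship I have in mind (a small sketch with illustrative sizes, separate from the actual test code below):

import torch
import torch.nn.functional as F

input = torch.randn(1, 1, 8, 8)      # (N, C_in, H, W)
weight = torch.randn(3, 1, 4, 4)     # (C_out=3, C_in=1, kH, kW)

full = F.conv2d(input, weight)                    # shape (1, 3, 5, 5): all 3 output channels
single = F.conv2d(input, weight[0].unsqueeze(0))  # shape (1, 1, 5, 5): output channel 0 only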

But the output in PyTorch is strange.

Here is the code to reproduce the strange behavior:

import torch
import torch.nn.functional as F

inp = torch.randn(1,1,4,4)

stride = 2
padding = 1

weight = torch.randn(3, 1, 4, 4)
p1 = weight[:, :, :2, :]   # top half of each kernel, shape (3, 1, 2, 4)
p2 = weight[:, :, 2:, :]   # bottom half of each kernel, shape (3, 1, 2, 4)

test = F.conv2d(inp, p2, bias=None, stride=(stride, stride))   # all 3 filters at once

test_0 = F.conv2d(inp, p2[0].unsqueeze(0), bias=None, stride=(stride, stride))  # filter 0 only
test_1 = F.conv2d(inp, p2[1].unsqueeze(0), bias=None, stride=(stride, stride))  # filter 1 only
test_2 = F.conv2d(inp, p2[2].unsqueeze(0), bias=None, stride=(stride, stride))  # filter 2 only

test_cat = torch.cat([test_0, test_1, test_2], dim=0)

The variable “test” should be the same as “test_cat” above, but it turns out that test[0] matches test_0 and test[2] matches test_2, while test[1] and test_1 are just some other, seemingly random values.
Also, this behavior only happens when I take the slice

p2 = weight[:,:,2:,:]

If I just use “weight” to do all the testing, it works fine.
Can anyone kindly help? Am I missing something or making a mistake here?

Also, from my testing, the result in “test_cat” is actually the desired convolution result, while the “test” one is the strange one.

You may be seeing a bug from an older version of PyTorch. Make sure your PyTorch is up to date (examine torch.__version__).

Also, the torch.cat call should be along dim 1, not dim 0 (dim 0 is the batch dimension). With that fixed, test and test_cat match for me (PyTorch 1.0.1.post2):

...
test_cat = torch.cat([test_0, test_1, test_2], dim=1)  # NOTE dim=1
torch.allclose(test, test_cat)  # True
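
To see why dim 1 is the right axis: each per-filter F.conv2d call returns a tensor of shape (N, 1, H_out, W_out), so concatenating along dim 1 rebuilds the output-channel axis, while dim 0 stacks the results as three separate batch entries. Here is a minimal sketch, re-creating inp and p2 with the same shapes as in your snippet:

import torch
import torch.nn.functional as F

inp = torch.randn(1, 1, 4, 4)   # (N=1, C_in=1, H=4, W=4)
p2 = torch.randn(3, 1, 2, 4)    # (C_out=3, C_in=1, kH=2, kW=4)

full = F.conv2d(inp, p2, stride=2)                                       # (1, 3, 2, 1)
parts = [F.conv2d(inp, p2[i].unsqueeze(0), stride=2) for i in range(3)]  # each (1, 1, 2, 1)

print(torch.cat(parts, dim=1).shape)                  # torch.Size([1, 3, 2, 1]) -- same as full
print(torch.cat(parts, dim=0).shape)                  # torch.Size([3, 1, 2, 1]) -- batch of 3
print(torch.allclose(full, torch.cat(parts, dim=1)))  # True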

Great, thank you! I will update my torch version.