Hi PyTorch community,

For a research project, I am trying to apply different sets of filters to elements in the batch dimension.

That is, I want to partition the batch into `G` sets and apply a different set of filters to each batch-dim group, in parallel. In the extreme case where the number of groups `G=N` equals the batch size, each batch element gets its own filter.

input dimension: `N x C x H x W`

filter dimension: `G x C_out x C_in x H_filter x W_filter`

output dimension: `N x C_out x H_out x W_out`

I believe this can be done by temporarily treating batch elements as channels and using the `groups=` argument of `Conv2d`. However, my attempt does not yield the correct result.

The code below works for `G=1` (a normal convolution, with the same filter applied to every batch element) and for `G=N` (a different filter for each batch element), but it fails for values of `G` strictly between 1 and N.

```
import torch
import torch.nn.functional as F

N, C, H, W = 128, 3, 28, 28
CC = 64
for groups in (1, 2, N):
    sub_size = N // groups
    # dummy data
    x = torch.randn(N, C, H, W)
    w = torch.randn(groups, CC, C, 7, 7)

    # desired computation: apply w[i] to the i-th contiguous block of the batch
    out = []
    for i in range(groups):
        sub_x = x[i * sub_size:(i + 1) * sub_size]
        sub_w = w[i]
        sub_out = F.conv2d(sub_x, sub_w)
        out.append(sub_out)
    out = torch.cat(out)

    # first attempt, exploiting groups= to perform the computation in parallel
    w_fast = w.view(groups * CC, C, 7, 7)
    x_fast = x.view(sub_size, groups * C, H, W)
    out_grouped = F.conv2d(x_fast, w_fast, groups=groups)
    out_grouped = out_grouped.view(N, CC, *out_grouped.shape[-2:])

    # should be the same:
    print((out - out_grouped).abs().max())

# >> tensor(0.)          <-- good, G=1
# >> tensor(91.9094)     <-- bad, G=2 fails
# >> tensor(4.5776e-05)  <-- good, G=N
```

What am I missing here?
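Edit: one rearrangement that seems to make the intermediate case line up in my quick tests, in case it helps pin down what's going on. My hypothesis is that `view` alone interleaves batch elements across channel groups (element `b*groups + g` lands in group `g`) rather than keeping each contiguous batch block together, so a transpose is needed before and after the grouped convolution. Sketch with small shapes, not benchmarked:

```
import torch
import torch.nn.functional as F

N, C, H, W = 8, 3, 16, 16
CC = 4
groups = 2
sub_size = N // groups
x = torch.randn(N, C, H, W)
w = torch.randn(groups, CC, C, 7, 7)

# reference: apply w[i] to the i-th contiguous block of the batch
out = torch.cat([F.conv2d(x[i * sub_size:(i + 1) * sub_size], w[i])
                 for i in range(groups)])

# grouped version: transpose so each contiguous batch block becomes one
# channel group, run a single grouped conv, then undo the transpose
x_fast = (x.view(groups, sub_size, C, H, W)
           .transpose(0, 1)
           .reshape(sub_size, groups * C, H, W))
w_fast = w.reshape(groups * CC, C, 7, 7)
out_grouped = F.conv2d(x_fast, w_fast, groups=groups)
out_grouped = (out_grouped.view(sub_size, groups, CC, *out_grouped.shape[-2:])
               .transpose(0, 1)
               .reshape(N, CC, *out_grouped.shape[-2:]))

print((out - out_grouped).abs().max())  # only float rounding error
```

Does this look right, or is there a cleaner way to express the reordering?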

Many thanks in advance!

Tycho