Hi PyTorch community,
For a research project, I am trying to apply different sets of filters to elements along the batch dimension. That is, I want to split the batch into G groups and apply a different filter bank to each group, in parallel. In the extreme case where the number of groups equals the batch size (G = N), each batch element gets its own filter.
Input: N x C x H x W
Weights: G x C_out x C_in x H_filter x W_filter
Output: N x C_out x H_out x W_out
I believe this can be done by temporarily treating batch elements as channels and using the `groups=` argument of `F.conv2d`. However, my attempt does not yield the correct result.
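For reference, my understanding of `groups=` is that the input channels are split into consecutive chunks, each chunk is convolved with its own chunk of filters, and the results are concatenated along the channel dimension. A minimal sanity-check sketch (the shapes here are arbitrary toy values, not the ones from my problem):

```python
import torch
import torch.nn.functional as F

g, c_in, c_out = 2, 4, 6  # toy values: 2 groups, c_in//g input channels per group
x = torch.randn(1, c_in, 8, 8)
w = torch.randn(c_out, c_in // g, 3, 3)  # grouped weight: (c_out, c_in/g, kH, kW)

# grouped convolution in one call
out = F.conv2d(x, w, groups=g)

# same computation done manually, one channel group at a time
chunks = []
for i in range(g):
    xi = x[:, i * (c_in // g):(i + 1) * (c_in // g)]   # i-th chunk of input channels
    wi = w[i * (c_out // g):(i + 1) * (c_out // g)]    # i-th chunk of filters
    chunks.append(F.conv2d(xi, wi))
manual = torch.cat(chunks, dim=1)

print(torch.allclose(out, manual, atol=1e-5))  # True
```

So far this matches the documented semantics; my plan was to reuse this channel-chunking mechanism for batch groups.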
The code below works for G=1 (a normal convolution, with the same filters applied to all batch elements) and for G=N (a different filter for each batch element). However, it fails for values of G between 1 and N.
```python
import torch
import torch.nn.functional as F

N, C, H, W = 128, 3, 28, 28
CC = 64  # number of output channels

for groups in (1, 2, N):
    sub_size = N // groups

    # dummy data
    x = torch.randn(N, C, H, W)
    w = torch.randn(groups, CC, C, 7, 7)

    # desired computation: loop over groups, one filter bank per batch group
    out = []
    for i in range(groups):
        sub_x = x[i*sub_size:i*sub_size + sub_size]
        sub_w = w[i]
        sub_out = F.conv2d(sub_x, sub_w)
        out.append(sub_out)
    out = torch.cat(out)

    # first attempt exploiting groups= to perform the computation in parallel
    w_fast = w.view(groups*CC, C, 7, 7)
    x_fast = x.view(sub_size, groups*C, H, W)
    out_grouped = F.conv2d(x_fast, w_fast, groups=groups)
    out_grouped = out_grouped.view(N, CC, *out_grouped.shape[-2:])

    # should be the same:
    print((out - out_grouped).abs().max())
```

which prints:

```
tensor(0.)          # <-- good, G=1
tensor(91.9094)     # <-- bad, G=2 fails
tensor(4.5776e-05)  # <-- good, G=N
```
What am I missing here?
Many thanks in advance!