Hi PyTorch community,
For a research project, I am trying to apply different sets of filters to elements along the batch dimension. That is, I want to split the batch into G groups and apply a different filter bank to each group, in parallel. In the extreme case where the number of groups equals the batch size (G = N), each batch element gets its own filter.
Input: N x C x H x W
Weights: G x C_out x C_in x H_filter x W_filter
Output: N x C_out x H_out x W_out
I believe this can be done by temporarily treating batch elements as channels and using the `groups=` argument of `F.conv2d`. However, my attempt does not yield the correct result.
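For reference, my understanding of `groups=` is that the input channels are split into consecutive chunks, each chunk is convolved with its own chunk of filters, and the results are concatenated along the channel dimension. A minimal sanity-check sketch (the shapes here are arbitrary toy values, not the ones from my problem):

```python
import torch
import torch.nn.functional as F

g, c_in, c_out = 2, 4, 6  # toy values: 2 groups, c_in//g input channels per group
x = torch.randn(1, c_in, 8, 8)
w = torch.randn(c_out, c_in // g, 3, 3)  # grouped weight: (c_out, c_in/g, kH, kW)

# grouped convolution in one call
out = F.conv2d(x, w, groups=g)

# same computation done manually, one channel group at a time
chunks = []
for i in range(g):
    xi = x[:, i * (c_in // g):(i + 1) * (c_in // g)]   # i-th chunk of input channels
    wi = w[i * (c_out // g):(i + 1) * (c_out // g)]    # i-th chunk of filters
    chunks.append(F.conv2d(xi, wi))
manual = torch.cat(chunks, dim=1)

print(torch.allclose(out, manual, atol=1e-5))  # True
```

So far this matches the documented semantics; my plan was to reuse this channel-chunking mechanism for batch groups.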
The code below works for G=1 (a normal convolution, with the same filters applied to all batch elements) and for G=N (a different filter for each batch element). However, it fails for values of G between 1 and N.
```python
import torch
import torch.nn.functional as F

N, C, H, W = 128, 3, 28, 28
CC = 64  # number of output channels

for groups in (1, 2, N):
    sub_size = N // groups

    # dummy data
    x = torch.randn(N, C, H, W)
    w = torch.randn(groups, CC, C, 7, 7)

    # desired computation: loop over groups, one filter bank per batch group
    out = []
    for i in range(groups):
        sub_x = x[i*sub_size:i*sub_size + sub_size]
        sub_w = w[i]
        sub_out = F.conv2d(sub_x, sub_w)
        out.append(sub_out)
    out = torch.cat(out)

    # first attempt exploiting groups= to perform the computation in parallel
    w_fast = w.view(groups*CC, C, 7, 7)
    x_fast = x.view(sub_size, groups*C, H, W)
    out_grouped = F.conv2d(x_fast, w_fast, groups=groups)
    out_grouped = out_grouped.view(N, CC, *out_grouped.shape[-2:])

    # should be the same:
    print((out - out_grouped).abs().max())
```

which prints:

```
tensor(0.)          # <-- good, G=1
tensor(91.9094)     # <-- bad, G=2 fails
tensor(4.5776e-05)  # <-- good, G=N
```
What am I missing here?
Many thanks in advance!