Batched Conv2d with filters "grouped" in batch dimension

Hi pytorch community,

For a research project, I am trying to apply different sets of filters to elements in the batch dimension.
That is, I want to split the batch into G contiguous groups and apply a different set of filters to each group, in parallel. In the extreme case where the number of groups equals the batch size (G=N), each batch element gets its own set of filters.

input dimension: N x C x H x W
filter dimension: G x C_out x C_in x H_filter x W_filter
output dimension: N x C_out x H_out x W_out
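Written out element-wise, the intended computation looks like this. It assumes G divides N evenly, which the code below also assumes via N // groups, and it writes S = N / G for the group size (sub_size in the code):

```latex
\mathrm{out}[n] = \mathrm{conv2d}\!\left(x[n],\; w\!\left[\left\lfloor n / S \right\rfloor\right]\right),
\qquad S = N / G, \qquad n = 0, \dots, N-1
```

So G=1 gives every element w[0], and G=N gives element n its own filter set w[n].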

I believe this can be done by temporarily treating batch elements as channels and using the groups= argument of Conv2d. However, my attempt does not always yield the correct result.
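Some background on the groups= argument helps here. F.conv2d with groups=G splits the input channels into G consecutive blocks and convolves block g only with the g-th block of output filters. The sketch below uses small, arbitrary sizes and illustrative variable names. It checks that equivalence against G separate convolutions on channel slices:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

G, C_in, C_out, K = 3, 2, 4, 3          # groups, in/out channels per group, kernel size
x = torch.randn(5, G * C_in, 8, 8)      # input channels = G consecutive blocks of C_in
w = torch.randn(G * C_out, C_in, K, K)  # weight layout for groups=G: (G*C_out, C_in, K, K)

grouped = F.conv2d(x, w, groups=G)

# Same computation, one channel block (and its block of filters) at a time
per_block = torch.cat(
    [F.conv2d(x[:, g * C_in:(g + 1) * C_in], w[g * C_out:(g + 1) * C_out]) for g in range(G)],
    dim=1,
)

print(grouped.shape)                                   # torch.Size([5, 12, 6, 6])
print(torch.allclose(grouped, per_block, atol=1e-4))   # True
```

Merging the batch into the channels therefore only works if each group's batch elements land in the channel block that is paired with that group's filters.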

The code below seems to work for G=1 (normal convolution, with the same filters applied to all batch elements) and for G=N (different filters for each batch element). However, it does not seem to work for G strictly between 1 and N, e.g. G=2.

import torch
import torch.nn.functional as F

N, C, H, W = 128, 3, 28, 28
CC = 64

for groups in (1, 2, N):
    sub_size = N // groups

    # dummy data
    x = torch.randn(N, C, H, W)
    w = torch.randn(groups, CC, C, 7, 7)

    # desired computation
    out = []
    for i in range(groups):
        sub_x = x[i*sub_size:i*sub_size + sub_size]
        sub_w = w[i]

        sub_out = F.conv2d(sub_x, sub_w)

        out.append(sub_out)
    out = torch.cat(out)

    # first attempt exploiting groups to perform the computation in parallel
    w_fast = w.view(groups*CC, C, 7, 7)
    x_fast = x.view(sub_size, groups*C, H, W)

    out_grouped = F.conv2d(x_fast, w_fast, groups=groups)
    out_grouped = out_grouped.view(N, CC, *out_grouped.shape[-2:])

    # should be the same:
    print((out - out_grouped).abs().max())
>> tensor(0.) # <-- good
>> tensor(91.9094) # <-- bad, G=2 fails.
>> tensor(4.5776e-05) # <-- good

What am I missing here?

Many thanks in advance!

Tycho

Hi Tycho!

This is exactly correct: you may merge the batch and channels dimensions
together and then use conv2d() with its groups argument.

However, in order to interleave the grouped batch elements and the channels
properly, you will need to insert a transpose() into the manipulations you use
to create x_fast, and a matching transpose() when reshaping the output back to
N x C_out x H_out x W_out. Note x_fastB and out_groupedB in this tweaked version
of the code you posted:

import torch
print (torch.__version__)

_ = torch.manual_seed (2022)

N, C, H, W = 128, 3, 28, 28
CC = 64

for groups in (1, 2, 16, 64, N):   # add a couple more group sizes
    sub_size = N // groups
    
    x = torch.randn(N, C, H, W)  
    w = torch.randn(groups, CC, C, 7, 7) 
    
    # desired computation
    out = []
    for i in range(groups):
        sub_x = x[i*sub_size:i*sub_size + sub_size]
        sub_w = w[i]
        
        sub_out = torch.nn.functional.conv2d (sub_x, sub_w)
        
        out.append(sub_out)
    out = torch.cat(out)
    
    # first attempt exploiting groups to perform the computation in parallel
    w_fast = w.view(groups*CC, C, 7, 7)
    x_fast = x.view(sub_size, groups*C, H, W)
    
    # need transpose() to "interleave" the input groups with one another properly
    x_fastB = x.reshape (groups, sub_size, C, *x.shape[-2:]).transpose (0, 1).reshape (sub_size, groups * C, *x.shape[-2:])
    
    out_grouped = torch.nn.functional.conv2d (x_fast, w_fast, groups = groups)
    out_grouped = out_grouped.view(N, CC, *out_grouped.shape[-2:])
    
    out_groupedB = torch.nn.functional.conv2d (x_fastB, w_fast, groups = groups)
    # and need transpose() to "deinterleave" the output groups
    out_groupedB = out_groupedB.reshape (sub_size, groups, CC, *out_groupedB.shape[-2:]).transpose (0, 1).reshape (N, CC, *out_groupedB.shape[-2:])
    
    # should be the same:
    print ('groups:', groups)
    print ('check out_grouped:', (out - out_grouped).abs().max())
    print ('check out_groupedB:', (out - out_groupedB).abs().max())

Here is its output:

1.13.0
groups: 1
check out_grouped: tensor(0.)
check out_groupedB: tensor(0.)
groups: 2
check out_grouped: tensor(88.9461)
check out_groupedB: tensor(0.)
groups: 16
check out_grouped: tensor(91.0194)
check out_groupedB: tensor(4.5776e-05)
groups: 64
check out_grouped: tensor(92.4197)
check out_groupedB: tensor(3.8147e-05)
groups: 128
check out_grouped: tensor(3.8147e-05)
check out_groupedB: tensor(3.8147e-05)
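Plain index arithmetic shows why the transpose() is needed, and why G=1 and G=N happen to work without it. With groups=G, channel block g of every row of the merged input is convolved with w[g]. The pure-Python sketch below records which filter each batch element n ends up with under the plain view() and under the reshape/transpose/reshape used for x_fastB. filter_used and filter_wanted are hypothetical helper names:

```python
def filter_used(N, G, interleave):
    """Filter index that conv2d(..., groups=G) applies to each batch element n."""
    S = N // G                       # batch elements per group (sub_size)
    used = [None] * N
    for r in range(S):               # row of the merged (S, G*C, H, W) input
        for g in range(G):           # channel block g of that row -> filter w[g]
            if interleave:
                n = g * S + r        # x.reshape(G, S, C, H, W).transpose(0, 1).reshape(...)
            else:
                n = r * G + g        # x.view(S, G*C, H, W)
            used[n] = g
    return used

def filter_wanted(N, G):
    """Filter index the reference loop applies: element n belongs to group n // S."""
    S = N // G
    return [n // S for n in range(N)]

N = 4
for G in (1, 2, 4):
    print(G, filter_wanted(N, G), filter_used(N, G, False), filter_used(N, G, True))
# 1 [0, 0, 0, 0] [0, 0, 0, 0] [0, 0, 0, 0]
# 2 [0, 0, 1, 1] [0, 1, 0, 1] [0, 0, 1, 1]
# 4 [0, 1, 2, 3] [0, 1, 2, 3] [0, 1, 2, 3]
```

The plain view() puts consecutive batch elements (which belong to the same group) into different channel blocks. It agrees with the reference only when G=1 or G=N (S=1), which matches the pattern in the check output above.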

Best.

K. Frank
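If the trick is needed in more than one place, the reshape/transpose steps from the answer can be wrapped in a function. The sketch below assumes N is divisible by G and passes extra conv2d keyword arguments (stride, padding, dilation) straight through. Passing groups itself would conflict. grouped_batch_conv2d is a hypothetical name, not a PyTorch API. The sketch is checked against the per-group loop from the question:

```python
import torch
import torch.nn.functional as F

def grouped_batch_conv2d(x, w, **conv_kwargs):
    """Apply filter set w[g] to the g-th contiguous block of N // G batch elements.

    x: (N, C_in, H, W); w: (G, C_out, C_in, kH, kW). Assumes N % G == 0.
    """
    N, C_in, H, W = x.shape
    G, C_out = w.shape[:2]
    S = N // G                                    # batch elements per group
    # (N, C_in, H, W) -> (S, G*C_in, H, W): group g fills channel block g of every row
    x_m = x.reshape(G, S, C_in, H, W).transpose(0, 1).reshape(S, G * C_in, H, W)
    w_m = w.reshape(G * C_out, *w.shape[2:])      # (G*C_out, C_in, kH, kW)
    out = F.conv2d(x_m, w_m, groups=G, **conv_kwargs)
    H_out, W_out = out.shape[-2:]
    # (S, G*C_out, H_out, W_out) -> (N, C_out, H_out, W_out): undo the interleave
    return out.reshape(S, G, C_out, H_out, W_out).transpose(0, 1).reshape(N, C_out, H_out, W_out)

# Compare against the explicit per-group loop
torch.manual_seed(0)
N, C, H, W, CC, G = 8, 3, 10, 10, 5, 4
x = torch.randn(N, C, H, W)
w = torch.randn(G, CC, C, 3, 3)
S = N // G
ref = torch.cat([F.conv2d(x[g * S:(g + 1) * S], w[g], padding=1) for g in range(G)])
fast = grouped_batch_conv2d(x, w, padding=1)
print(fast.shape, (ref - fast).abs().max())
```

Using reshape() rather than view() keeps the function working even if the weight tensor passed in is not contiguous.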