Unfold tensor and apply conv2d to each tile

I have unfolded a tensor and the resulting dimensions are: [batch, tilesCount, chanCount, h, w]. I would like to apply a different conv2d operation to each tile (a list of tilesCount conv2ds). Can this be done on a single tensor, or do I have to split it into a list and apply the convolutions separately in a loop? I have been looking at groups, but got slightly lost as to whether it can be used here or not.
As an output, I need: [batch, tilesCount, chanCountAfterConv, h, w] (the output h and w match the input's via padding)

Hi Martin!

Yes, you can use groups to give each tile its own convolution kernel
and bias.

To do this, reshape() (or possibly view()) your input tensor so that
you combine your tilesCount dimension with your chanCount dimension.
Using groups = tilesCount in effect splits your single Conv2d into
tilesCount independent Conv2ds, one per tile. Then reshape() the output
of your convolution to put your tiles back into their own dimension.

Here is a script that illustrates this approach using the shapes you posted:

import torch
print (torch.__version__)

_ = torch.manual_seed (2022)

batch = 2
tilesCount = 5
chanCount = 3
h = 8
w = 8

chanCountAfterConv = 2

input = torch.randn (batch, tilesCount, chanCount, h, w)
print ('input.shape:', input.shape)

input[0, 3] = 0.0   # zero out tile-3 of batch-0
input[1, 1] = 0.0   # zero out tile-1 of batch-1
print ('print out channel-1 of a zero-tile')
print (input[0, 3, 1])

t = input.reshape (batch, tilesCount * chanCount, h, w)   # combine tiles into channels dimension
print ('t.shape:', t.shape)

conv = torch.nn.Conv2d (   # each tile is its own group with its own kernel
    in_channels = tilesCount * chanCount,
    out_channels = tilesCount * chanCountAfterConv,
    kernel_size = 3,
    padding = 1,
    groups = tilesCount,   # each group is a tile
    bias = True
)
print ('conv.weight.shape:', conv.weight.shape)

u = conv (t)   # apply grouped convolution
print ('u.shape:', u.shape)

output = u.reshape (batch, tilesCount, chanCountAfterConv, h, w)   # separate tiles from channels dimensions
print ('output.shape:', output.shape)

# verify that post-convolution zero-tiles are pure bias
br = conv.bias.reshape (tilesCount, chanCountAfterConv)
print ('br.shape:', br.shape)   # five tiles, each with two channels
print ('check that output for zero-tiles is pure bias')
print (torch.equal (output[0, 3], br[3].unsqueeze (-1).unsqueeze (-1).expand (chanCountAfterConv, h, w)))
print (torch.equal (output[1, 1], br[1].unsqueeze (-1).unsqueeze (-1).expand (chanCountAfterConv, h, w)))

And here is its output:

1.13.0
input.shape: torch.Size([2, 5, 3, 8, 8])
print out channel-1 of a zero-tile
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
t.shape: torch.Size([2, 15, 8, 8])
conv.weight.shape: torch.Size([10, 3, 3, 3])
u.shape: torch.Size([2, 10, 8, 8])
output.shape: torch.Size([2, 5, 2, 8, 8])
br.shape: torch.Size([5, 2])
check that output for zero-tiles is pure bias
True
True
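If you want to confirm that the grouped convolution really is equivalent
to the loop you were considering, here is a sketch of such a check. It
slices the grouped Conv2d's weight and bias into per-tile pieces (group i
owns output channels i * chanCountAfterConv through (i + 1) * chanCountAfterConv),
builds a plain Conv2d per tile from those slices, and compares the two
results (the variable names here are just for illustration):

```python
import torch

_ = torch.manual_seed (2022)

batch, tilesCount, chanCount, h, w = 2, 5, 3, 8, 8
chanCountAfterConv = 2

inp = torch.randn (batch, tilesCount, chanCount, h, w)

conv = torch.nn.Conv2d (
    in_channels = tilesCount * chanCount,
    out_channels = tilesCount * chanCountAfterConv,
    kernel_size = 3,
    padding = 1,
    groups = tilesCount,
)

# grouped version: combine tiles into the channels dimension,
# convolve, then split the tiles back out
outGrouped = conv (inp.reshape (batch, tilesCount * chanCount, h, w))
outGrouped = outGrouped.reshape (batch, tilesCount, chanCountAfterConv, h, w)

# loop version: one plain Conv2d per tile, its weight and bias copied
# from the corresponding slice of the grouped convolution
outs = []
for i in range (tilesCount):
    ci = torch.nn.Conv2d (chanCount, chanCountAfterConv, kernel_size = 3, padding = 1)
    sl = slice (i * chanCountAfterConv, (i + 1) * chanCountAfterConv)
    with torch.no_grad():
        ci.weight.copy_ (conv.weight[sl])   # shape [chanCountAfterConv, chanCount, 3, 3]
        ci.bias.copy_ (conv.bias[sl])
    outs.append (ci (inp[:, i]))            # convolve tile i on its own
outLoop = torch.stack (outs, dim = 1)       # re-stack tiles into their own dimension

print (outGrouped.shape)   # torch.Size([2, 5, 2, 8, 8])
print (torch.allclose (outGrouped, outLoop, atol = 1.e-6))
```

The two results agree (up to floating-point round-off), so the grouped
convolution is doing the same thing as the per-tile loop, just in a
single kernel launch.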

Best.

K. Frank