I have a tensor with a batch of multi-channel 2D image patches and I want to efficiently apply the same convolutional filters to each image patch. How can I do this?

Let’s say:

B = # batches

N = # patches per batch

C = # input channels per patch

D = # output channels per patch

H = height of patch

W = width of patch

My input tensor `x` is B x N x C x H x W.

Given a convolution function `conv(x)`, I want the output to be B x N x D x H x W.

Typically, these patches are spatially fairly small, on the order of 8x8 pixels or so. However, N can be like 20,000+.

I have tried collapsing the first two dimensions to a BN x C x H x W tensor, but it appears that PyTorch’s convolution implementation is designed such that the convolution operation is parallelized per batch within a for-loop. Thus, I’m just looping over BN batches sequentially. This is quite slow (I’ve tested it) and doesn’t really take advantage of a GPU at all as it’s largely sequential.
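For reference, the collapsing approach I tried looks like this (sizes here are illustrative and much smaller than my real ones):

```python
import torch
import torch.nn.functional as F

# Illustrative sizes (my real N is 20,000+)
B, N, C, D, H, W, K = 2, 6, 3, 4, 8, 8, 3

x = torch.randn(B, N, C, H, W)
w = torch.randn(D, C, K, K)  # one shared filter bank for all patches

# Collapse batch and patch dims into one big batch, convolve, restore shape
y = F.conv2d(x.reshape(B * N, C, H, W), w, padding=K // 2)
y = y.reshape(B, N, D, H, W)  # (B, N, D, H, W)
```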

The other thing I looked at was a grouped convolution. I collapsed dimensions `1` and `2` so I had a B x NC x H x W tensor, and then set `groups=N`. However, this results in a B x D x H x W output, whereas I would be looking for a B x ND x H x W output. If I explicitly provide that multiplier to the number of channels (e.g. `out_channels=N*1024` or something), however, PyTorch tries to initialize a massive weight matrix for the convolution, which is certainly unnecessary and obviously exceeds the capabilities of my device.
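Concretely, the grouped-convolution route only type-checks if the shared filter bank is physically repeated N times along the output-channel dimension, since `conv2d` with `groups=N` expects a weight of shape (N*D, C, K, K) — that repetition is exactly the blow-up I'm describing (sizes illustrative):

```python
import torch
import torch.nn.functional as F

B, N, C, D, H, W, K = 2, 6, 3, 4, 8, 8, 3

x = torch.randn(B, N, C, H, W)
w = torch.randn(D, C, K, K)  # the single shared filter bank

xg = x.reshape(B, N * C, H, W)  # collapse patch and channel dims
wg = w.repeat(N, 1, 1, 1)       # (N*D, C, K, K): N redundant copies of w

y = F.conv2d(xg, wg, groups=N, padding=K // 2)  # (B, N*D, H, W)
y = y.reshape(B, N, D, H, W)
```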

Finally, I considered tiling the patches so that the tensor was B x C x NH x W (or some variation thereof), but this would now require some funky striding or padding to make sure I wasn’t sampling between patches.

*In summary, I am looking for an efficient operation that will apply a D x C x K x K weight matrix to my B x N x C x H x W input tensor and produce a B x N x D x H x W output tensor.* That is, a convolution in groups, but not a grouped convolution. **Is this possible using the existing tools in PyTorch?** Everything is so nicely aligned in memory that it seems like it ought to be. I am open to any necessary permutation of dimensions.
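To pin down the target semantics exactly, here is a naive reference implementation via `unfold` — correct, but obviously not the efficient operation I'm after (sizes illustrative):

```python
import torch
import torch.nn.functional as F

B, N, C, D, H, W, K = 2, 6, 3, 4, 8, 8, 3

x = torch.randn(B, N, C, H, W)
w = torch.randn(D, C, K, K)  # the shared D x C x K x K weight

# Extract K x K neighborhoods from every patch: (B*N, C*K*K, H*W)
cols = F.unfold(x.reshape(B * N, C, H, W), K, padding=K // 2)

# One shared matmul against the flattened weight, then restore the shape
y = torch.einsum('bim,oi->bom', cols, w.reshape(D, -1))
y = y.reshape(B, N, D, H, W)  # (B, N, D, H, W)
```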