I have a tensor with a batch of multi-channel 2D image patches and I want to efficiently apply the same convolutional filters to each image patch. How can I do this?

Let’s say:

B = # batches

N = # patches per batch

C = # input channels per patch

D = # output channels per patch

H = height of patch

W = width of patch

My input tensor `x` is B x N x C x H x W.

Given a convolution function `conv(x)`, I want the output to be B x N x D x H x W.

Typically, these patches are spatially fairly small, on the order of 8x8 pixels or so. However, N can be like 20,000+.

I have tried collapsing the first two dimensions to a BN x C x H x W tensor, but it appears that PyTorch’s convolution implementation is designed such that the convolution operation is parallelized per batch within a for-loop. Thus, I’m just looping over BN batches sequentially. This is quite slow (I’ve tested it) and doesn’t really take advantage of a GPU at all as it’s largely sequential.
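For reference, the collapsing approach I tried looks like this (sizes here are illustrative and much smaller than my real ones):

```python
import torch
import torch.nn.functional as F

# Illustrative sizes (my real N is 20,000+)
B, N, C, D, H, W, K = 2, 6, 3, 4, 8, 8, 3

x = torch.randn(B, N, C, H, W)
w = torch.randn(D, C, K, K)  # one shared filter bank for all patches

# Collapse batch and patch dims into one big batch, convolve, restore shape
y = F.conv2d(x.reshape(B * N, C, H, W), w, padding=K // 2)
y = y.reshape(B, N, D, H, W)  # (B, N, D, H, W)
```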

The other thing I looked at was a grouped convolution. I collapsed dimensions `1` and `2` so I had a B x NC x H x W tensor, and then set `groups=N`. However, this results in a B x D x H x W output, whereas I would be looking for a B x ND x H x W output. If I explicitly provide that multiplier to the number of channels (e.g. `out_channels=N*1024` or something), however, PyTorch tries to initialize a massive weight matrix for the convolution, which is certainly unnecessary and obviously exceeds the capabilities of my device.
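Concretely, the grouped-convolution route only type-checks if the shared filter bank is physically repeated N times along the output-channel dimension, since `conv2d` with `groups=N` expects a weight of shape (N*D, C, K, K) — that repetition is exactly the blow-up I'm describing (sizes illustrative):

```python
import torch
import torch.nn.functional as F

B, N, C, D, H, W, K = 2, 6, 3, 4, 8, 8, 3

x = torch.randn(B, N, C, H, W)
w = torch.randn(D, C, K, K)  # the single shared filter bank

xg = x.reshape(B, N * C, H, W)  # collapse patch and channel dims
wg = w.repeat(N, 1, 1, 1)       # (N*D, C, K, K): N redundant copies of w

y = F.conv2d(xg, wg, groups=N, padding=K // 2)  # (B, N*D, H, W)
y = y.reshape(B, N, D, H, W)
```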

Finally, I considered tiling the patches so that the tensor was B x C x NH x W (or some variation thereof), but this would now require some funky striding or padding to make sure I wasn’t sampling between patches.

*In summary, I am looking for an efficient operation that will apply a D x C x K x K weight matrix to my B x N x C x H x W input tensor and produce a B x N x D x H x W output tensor.* That is, a convolution in groups, but not a grouped convolution. **Is this possible using the existing tools in PyTorch?** Everything is so nicely aligned in memory that it seems like it ought to be. I am open to any necessary permutation of dimensions.
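To pin down the target semantics exactly, here is a naive reference implementation via `unfold` — correct, but obviously not the efficient operation I'm after (sizes illustrative):

```python
import torch
import torch.nn.functional as F

B, N, C, D, H, W, K = 2, 6, 3, 4, 8, 8, 3

x = torch.randn(B, N, C, H, W)
w = torch.randn(D, C, K, K)  # the shared D x C x K x K weight

# Extract K x K neighborhoods from every patch: (B*N, C*K*K, H*W)
cols = F.unfold(x.reshape(B * N, C, H, W), K, padding=K // 2)

# One shared matmul against the flattened weight, then restore the shape
y = torch.einsum('bim,oi->bom', cols, w.reshape(D, -1))
y = y.reshape(B, N, D, H, W)  # (B, N, D, H, W)
```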