Batch F.conv2d with unique filters on minibatch dimension

Hi,
I was wondering if anyone could help me solve this problem. Say I have a minibatch of inputs:

x = torch.randn(N, C_in, iH, iW)

Along with that, I have N unique sets of filters:

W = torch.randn(N, C_out, C_in, kH, kW)

What I’d like is something like:

# let's say N=10, C_out=3, iH=iW=5, kH=kW=1
y = F.batch_conv2d(x, W)  # sample i is convolved only with its own filters W[i]
assert y.shape == (10, 3, 5, 5)
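To pin down what `batch_conv2d` should compute, here is a slow but easy-to-trust reference that loops over the minibatch. It is only a sketch: the name `batch_conv2d_reference` is hypothetical (no such function exists in PyTorch), and `C_in=4` is an arbitrary assumption because the question doesn't fix it.

```python
import torch
import torch.nn.functional as F

def batch_conv2d_reference(x, W, **conv_kwargs):
    """Hypothetical helper: convolve sample i only with filter bank W[i].

    x: (N, C_in, iH, iW), W: (N, C_out, C_in, kH, kW) -> (N, C_out, oH, oW).
    Uses a Python loop, so it is slow, but it defines the intended semantics.
    """
    assert x.shape[0] == W.shape[0], "need one filter bank per sample"
    outs = [F.conv2d(x[i:i + 1], W[i], **conv_kwargs) for i in range(x.shape[0])]
    return torch.cat(outs, dim=0)

# Sizes from the question; C_in=4 is an assumption.
N, C_in, C_out = 10, 4, 3
x = torch.randn(N, C_in, 5, 5)
W = torch.randn(N, C_out, C_in, 1, 1)
y = batch_conv2d_reference(x, W)
assert y.shape == (10, 3, 5, 5)
```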

I’ve searched the forum for something similar, but the closest I’ve found is someone who wanted to split the filters across the input channels (i.e. a depthwise convolution), which they solved with groups. I need the split to be along the minibatch dimension instead.
I also tried F.conv3d with the minibatch moved into the depth (iT) dimension, but that doesn’t give me what I need in the output: conv3d slides the same filters along that dimension, so it can’t apply a different filter set to each sample. It just wasn’t designed for this use case.
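For context on the groups mechanism mentioned above: with `groups=g`, `F.conv2d` splits the input channels into `g` blocks and the filters into `g` blocks, and filter block j only sees input block j, so the weight has shape `(C_out, C_in // g, kH, kW)`. The sketch below (sizes are arbitrary assumptions) checks that a depthwise convolution (`groups=C`) matches convolving each channel separately.

```python
import torch
import torch.nn.functional as F

C = 3
x = torch.randn(1, C, 5, 5)
w = torch.randn(C, 1, 3, 3)              # depthwise: one 3x3 filter per channel
y = F.conv2d(x, w, padding=1, groups=C)  # channel c only meets filter c

# Same result computed one channel at a time.
y_loop = torch.cat(
    [F.conv2d(x[:, c:c + 1], w[c:c + 1], padding=1) for c in range(C)], dim=1
)
assert torch.allclose(y, y_loop, atol=1e-5)
```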

Any ideas?

I think I’ve come up with a solution by stacking the samples along the channel dimension and misusing groups:

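>>> import torch
>>> import torch.nn.functional as F
>>> x1, x2, x3, x4 = (torch.randn(1, 3, 5, 5) for _ in range(4))
>>> x1.shape
torch.Size([1, 3, 5, 5])
>>> x = torch.cat([x1, x2, x3, x4], dim=1)  # samples stacked along channels
>>> x.shape
torch.Size([1, 12, 5, 5])
>>> w = torch.randn(12, 3, 1, 1)  # 4 groups x 3 filters, each over 3 input channels
>>> y = F.conv2d(x, w, groups=4)
>>> y.shape
torch.Size([1, 12, 5, 5])
>>> y = y.view(4, 3, 5, 5)  # split the channel blocks back into 4 samples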
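The same trick should generalize to any N by folding the batch into the channel dimension with `reshape` instead of `torch.cat`. This is a sketch, not a PyTorch API: `batch_conv2d` is a hypothetical helper name, and it assumes `x` and `W` have the shapes from the question. Extra keyword arguments (stride, padding, dilation) are forwarded to `F.conv2d`; `groups` must not be passed since the helper sets it.

```python
import torch
import torch.nn.functional as F

def batch_conv2d(x, W, **conv_kwargs):
    """Hypothetical helper: per-sample conv2d via the groups trick.

    x: (N, C_in, iH, iW), W: (N, C_out, C_in, kH, kW) -> (N, C_out, oH, oW).
    """
    N, C_in, iH, iW = x.shape
    N_w, C_out, C_in_w, kH, kW = W.shape
    assert N == N_w and C_in == C_in_w, "shape mismatch between x and W"
    # Channel n*C_in + c holds channel c of sample n, so group n is sample n.
    x_folded = x.reshape(1, N * C_in, iH, iW)
    # Output channels n*C_out ... (n+1)*C_out - 1 are sample n's filters.
    W_folded = W.reshape(N * C_out, C_in, kH, kW)
    y = F.conv2d(x_folded, W_folded, groups=N, **conv_kwargs)
    # Unfold the N channel blocks back into the batch dimension.
    return y.view(N, C_out, y.shape[-2], y.shape[-1])

# Usage with the sizes from the question (C_in=4 assumed).
x = torch.randn(10, 4, 5, 5)
W = torch.randn(10, 3, 4, 1, 1)
y = batch_conv2d(x, W)
assert y.shape == (10, 3, 5, 5)
```

Since it only uses `reshape`, `view`, and a single `conv2d` call, autograd handles it like any other convolution.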

I’m not 100% sure this is the right thing to do, so I’d still appreciate any thoughts.
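One way to check the stacking approach is to compare each sample of the grouped result against an ordinary convolution of that sample with its own three filters, `w[3*i : 3*i + 3]`. The sketch below rebuilds the 4-sample example with random data; the seed and variable names are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
xs = [torch.randn(1, 3, 5, 5) for _ in range(4)]   # x1..x4
x = torch.cat(xs, dim=1)                            # (1, 12, 5, 5)
w = torch.randn(12, 3, 1, 1)                        # rows 3i..3i+2 belong to sample i
y = F.conv2d(x, w, groups=4).view(4, 3, 5, 5)

# Sample i should only be convolved with its own three filters.
for i, xi in enumerate(xs):
    expected = F.conv2d(xi, w[3 * i:3 * i + 3])     # (1, 3, 5, 5)
    assert torch.allclose(y[i], expected[0], atol=1e-5)
print("grouped result matches per-sample convolution")
```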