Applying different convolutions on different batch elements

Hey there,

I want to build a neural network where, on a given layer, the operation applied is not the same for all instances of a batch. For instance, if my network is an MLP, I want each instance’s feature vector to be multiplied by a different matrix. In that case, I can get the result by using torch.matmul with a weight tensor that has an extra batch dimension.

When I try to do the same with a CNN, I can’t: nn.functional.conv2d does not accept a weight tensor with an extra batch dimension. To get around this, I compute each convolution sequentially and stack the results. While it works, it does not use my GPU efficiently and is thus really slow.
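A sketch of the sequential workaround described above, with illustrative sizes:

```python
import torch
import torch.nn.functional as F

batch, c_in, c_out, h, w = 8, 3, 16, 32, 32
x = torch.randn(batch, c_in, h, w)
weights = torch.randn(batch, c_out, c_in, 3, 3)  # one kernel set per instance

# F.conv2d expects a weight of shape (c_out, c_in, kH, kW) -- no batch
# dimension -- so each sample is convolved separately and the outputs stacked.
out = torch.stack([
    F.conv2d(x[i:i + 1], weights[i], padding=1).squeeze(0)
    for i in range(batch)
])
print(out.shape)  # torch.Size([8, 16, 32, 32])
```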

My question is the following: is there something in the PyTorch API that I missed that I could use to achieve instance-dependent convolutions, similar to how matmul works? If not, is there any particular reason why?


You could try to fold the batch dimension into the channel dimension and use a grouped convolution with groups equal to the batch size, so that each group of kernels is only applied to the channels of a single instance.
After the convolution is done, you can reshape the output back to the original shape.

Clever solution. It worked and sped up my training 4 or 5x.

That being said, there is still a significant speed gap between this and a standard CNN. Is it expected that grouped convolutions are slower than standard ones?

Thank you for your help!