Applying different convolutions to different batch elements

Hey there,

I want to build a neural network where, on a given layer, the operation applied is not the same for every instance in a batch. For instance, if my network is an MLP, I want each instance's feature vector to be multiplied by a different matrix. In that case, I can get the result I want by using torch.matmul and extending the weight tensor across the batch dimension.
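Concretely, for the MLP case, something like this works (sizes are made up):

import torch

batch, in_features, out_features = 32, 64, 128
x = torch.randn(batch, in_features)                      # one feature vector per instance
weights = torch.randn(batch, in_features, out_features)  # a different matrix per instance

# matmul broadcasts over the leading batch dimension:
# (batch, 1, in_features) @ (batch, in_features, out_features) -> (batch, 1, out_features)
y = torch.matmul(x.unsqueeze(1), weights).squeeze(1)     # (batch, out_features)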

When I try to do the same with a CNN, I can't: nn.functional.conv2d does not accept a weight tensor with an extra batch dimension. To get around this, I compute each convolution sequentially and stack the results (see the sketch below). While this works, it does not use my GPU efficiently and is therefore really slow.
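My current workaround looks roughly like this (a sketch with made-up names and sizes; per_sample_weights is a hypothetical tensor holding one kernel set per instance):

import torch
import torch.nn.functional as F

batch, c_in, c_out = 32, 3, 8
x = torch.randn(batch, c_in, 16, 16)
per_sample_weights = torch.randn(batch, c_out, c_in, 3, 3)  # a different kernel set per instance

# one conv2d call per instance, then stack the results along the batch dimension
outputs = [
    F.conv2d(x[i].unsqueeze(0), per_sample_weights[i], padding=1)
    for i in range(batch)
]
y = torch.cat(outputs, dim=0)  # (batch, c_out, 16, 16)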

My question is the following: is there something in the PyTorch API that I missed and that I could use to achieve instance-dependent convolutions, similar to how matmul works? If not, is there any particular reason why?

Thanks!

You could try to move the batch dimension into the channel dimension and use a grouped convolution, where e.g. each kernel is only applied to a single input channel.
After the convolution is done you could reshape the output back to the original shape.
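A sketch of that idea for per-sample weights (names and sizes are made up, not a drop-in for any particular model):

import torch
import torch.nn.functional as F

batch, c_in, c_out, h, w, k = 32, 3, 8, 16, 16, 3
x = torch.randn(batch, c_in, h, w)
weights = torch.randn(batch, c_out, c_in, k, k)  # a different kernel set per sample

# Fold the batch dimension into the channel dimension ...
x_grouped = x.reshape(1, batch * c_in, h, w)
w_grouped = weights.reshape(batch * c_out, c_in, k, k)

# ... and let groups=batch keep the samples separate: group i only sees
# sample i's channels and sample i's kernels.
y = F.conv2d(x_grouped, w_grouped, padding=k // 2, groups=batch)
y = y.reshape(batch, c_out, h, w)  # back to the original layout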

Clever solution. It worked and sped up my training 4 or 5x.

That being said, there is still a significant speed gap between this approach and a standard CNN. Is it expected that grouped convolutions are slower than standard ones?

Thank you for your help!

@ptrblck I actually need to do the same, but the channel dimension is larger than one. Is there an efficient way to solve this (i.e., deep feature correlation parallelized across the batch)?
Currently I am using the following slow sequential solution:

n, c, h, w = images1_features.shape
for i in range(n):
    corr = F.conv2d(images1_features[i].unsqueeze(0), images2_features[i].unsqueeze(0), padding=(h, w))

If you want to use separate filters for each input channel, you can use a grouped convolution via the groups argument.

@ptrblck That's not quite what I want. I want to correlate N feature tensors of C channels each with N other feature tensors of C channels each. You can think of this as "template matching" between two different sets of features. I have added more detail to the code I posted above to clarify:

images1_features = some_cnn(images1)  # images1_features.shape = (32, 100, 8, 8)
images2_features = some_cnn(images2)  # images2_features.shape = (32, 100, 8, 8)
# I now have 32 pairs of tensors representing the deep features of the pictures in images1 and images2.
# The pictures in images2 are shifted versions of the pictures in images1, so I would like to
# cross-correlate corresponding feature maps to compute the spatial alignment.
# The following is a sequential way of doing this, but I would like the processing to be
# parallel across the batch.
n, c, h, w = images1_features.shape
peak_locations = []
for i in range(n):
    corr = F.conv2d(images1_features[i].unsqueeze(0), images2_features[i].unsqueeze(0), padding=(h, w))
    peak_location = peak_interpolate(corr)
    peak_locations.append(peak_location)
peak_locations = torch.cat(peak_locations, dim=0)  # peak_locations.shape = (32, 2): one (x, y) displacement per image
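For reference, I tried to sketch how the grouped-convolution trick from above might batch this loop (untested; peak_interpolate is my own function and stays as-is):

n, c, h, w = images1_features.shape

# Fold the batch into the channel dimension of a single "image":
inputs = images1_features.reshape(1, n * c, h, w)

# images2_features already has conv2d's weight layout (out_channels, in_channels / groups, kH, kW)
# when groups=n: one c-channel filter per sample.
corr = F.conv2d(inputs, images2_features, padding=(h, w), groups=n)

# corr has shape (1, n, 2*h + 1, 2*w + 1): channel i is the cross-correlation of
# images1_features[i] with images2_features[i], summed over the c feature channels.
corr = corr.transpose(0, 1)  # (n, 1, 2*h + 1, 2*w + 1), same layout as in the loop above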