How to apply different kernels to each example in a batch when using convolution?

F.conv2d only supports applying the same kernel to all examples in a batch.
However, I want to apply different kernels to each example. How can I do this?

The most naive approach seems to be the code below:

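import torch
import torch.nn.functional as F

def parallel_conv2d(inputs, filters, stride=1, padding=1):
  # inputs: [N, C_in, H, W], filters: [N, C_out, C_in, kH, kW]
  batch_size = inputs.size(0)

  # convolve each sample with its own filter set, then stack along dim 0
  output_slices = [F.conv2d(inputs[i:i+1], filters[i], bias=None, stride=stride, padding=padding).squeeze(0)
                   for i in range(batch_size)]
  return torch.stack(output_slices, dim=0)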
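A quick way to check that such a loop really applies filters[i] only to inputs[i] is to give each sample a 1x1 identity kernel scaled by a known factor. The sketch below restates the loop with torch.cat instead of stack + squeeze (equivalent here); the sizes and scaling factors are made up for illustration.

```python
import torch
import torch.nn.functional as F


def parallel_conv2d(inputs, filters, stride=1, padding=1):
    # inputs:  [N, C_in, H, W]
    # filters: [N, C_out, C_in, kH, kW]; filters[i] is applied to inputs[i] only
    outputs = [
        F.conv2d(inputs[i:i + 1], filters[i], stride=stride, padding=padding)
        for i in range(inputs.size(0))
    ]
    return torch.cat(outputs, dim=0)  # [N, C_out, H_out, W_out]


N, C = 4, 2
inputs = torch.randn(N, C, 8, 8)
# 1x1 identity kernels scaled by i, so sample i should come out multiplied by i
eye = torch.eye(C).view(C, C, 1, 1)
filters = torch.stack([i * eye for i in range(N)])  # [N, C, C, 1, 1]

out = parallel_conv2d(inputs, filters, padding=0)
scales = torch.arange(N, dtype=inputs.dtype).view(N, 1, 1, 1)
print(out.shape)                             # torch.Size([4, 2, 8, 8])
print(torch.allclose(out, inputs * scales))  # True
```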

Is there a better, more efficient implementation or an API that I can use?

You could use a grouped convolution, which uses a separate set of filters for each group of input channels; in the case of a depthwise convolution, each input channel gets its own filter.
Have a look at the docs for more information.
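As a minimal sketch of what the groups argument changes (channel count and spatial size are arbitrary): with groups=C each input channel is convolved with its own filter, which shows up directly in the weight shape [out_channels, in_channels // groups, kH, kW].

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

C = 3
x = torch.randn(1, C, 16, 16)

# groups=C (depthwise): each input channel is convolved with its own filter
depthwise = nn.Conv2d(C, C, kernel_size=3, padding=1, groups=C, bias=False)
# groups=1 (default): each output channel mixes all input channels
dense = nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False)

# weight shape is [out_channels, in_channels // groups, kH, kW]
print(depthwise.weight.shape)  # torch.Size([3, 1, 3, 3])
print(dense.weight.shape)      # torch.Size([3, 3, 3, 3])

# output channel c of the depthwise conv only depends on input channel c
y = depthwise(x)
y_manual = torch.cat(
    [F.conv2d(x[:, c:c + 1], depthwise.weight[c:c + 1], padding=1) for c in range(C)],
    dim=1,
)
print(torch.allclose(y, y_manual, atol=1e-5))  # True
```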

Hi, thanks for your reply.
However, I think there is a misunderstanding. I apologize that my question was not clear enough.

What I meant is that I want to apply a different kernel to each image, not to each channel.
More specifically, the arguments of parallel_conv2d are:
inputs: mini_batch x in_channel x iH x iW tensor - a batch of input “images”
filters: mini_batch x out_channel x in_channel x kH x kW tensor - a batch of convolutional kernels

What I want to do is apply filters[i] to inputs[i].

Thanks for the update; I clearly misunderstood the use case.
If the kernel shapes differ, I think you would need to use a loop and concatenate the outputs afterwards (as long as the outputs have the same shape), since the filters cannot be stored directly in a single tensor.
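A minimal sketch of that loop, assuming hypothetical per-sample kernels of sizes 3, 5 and 7: the filters live in a Python list, and padding=k // 2 keeps the spatial size unchanged (odd k, stride 1), so the outputs can still be concatenated. If the output sizes differed, you would keep the list instead.

```python
import torch
import torch.nn.functional as F

N, C_in, C_out, H, W = 3, 3, 4, 16, 16
x = torch.randn(N, C_in, H, W)

# hypothetical per-sample kernels with different spatial sizes; they cannot be
# stacked into one tensor, so they are kept in a Python list
kernel_sizes = [3, 5, 7]
filters = [torch.randn(C_out, C_in, k, k) for k in kernel_sizes]

outputs = []
for i, w in enumerate(filters):
    k = w.shape[-1]
    # padding=k // 2 keeps H and W unchanged for odd k and stride 1
    outputs.append(F.conv2d(x[i:i + 1], w, padding=k // 2))

out = torch.cat(outputs, dim=0)
print(out.shape)  # torch.Size([3, 4, 16, 16])
```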

However, if the kernels all have the same shape, the grouped conv approach might still work.
Here is a small example using convolutions with in_channels=3, out_channels=15 for a batch size of 10:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Setup
N, C, H, W = 10, 3, 24, 24
x = torch.randn(N, C, H, W)

# Create filterset for each sample
weights = []
for _ in range(N):
    weight = nn.Parameter(torch.randn(15, 3, 5, 5))
    weights.append(weight)

# Apply manually
outputs = []
for idx in range(N):
    input = x[idx:idx+1]
    weight = weights[idx]
    output = F.conv2d(input, weight, stride=1, padding=2)
    outputs.append(output)

outputs = torch.stack(outputs)
outputs = outputs.squeeze(1) # remove fake batch dimension
print(outputs.shape)
> torch.Size([10, 15, 24, 24])

# Use grouped approach
weights = torch.stack(weights)
weights = weights.view(-1, 3, 5, 5)
print(weights.shape)
> torch.Size([150, 3, 5, 5])
# move batch dim into channels
x = x.view(1, -1, H, W)
print(x.shape)
> torch.Size([1, 30, 24, 24])
# Apply grouped conv
outputs_grouped = F.conv2d(x, weights, stride=1, padding=2, groups=N)
outputs_grouped = outputs_grouped.view(N, 15, 24, 24)

# Compare
print((outputs - outputs_grouped).abs().max())
> tensor(1.3351e-05, grad_fn=<MaxBackward1>)
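The same folding trick can be wrapped in a small helper. The sketch below is hypothetical (per_sample_conv is not a PyTorch API); it assumes every sample's filter set has the same shape, ignores bias, and dispatches to F.conv1d/conv2d/conv3d based on the input rank, then checks itself against the per-sample loop.

```python
import torch
import torch.nn.functional as F

# maps input rank -> functional conv: (N, C, L) / (N, C, H, W) / (N, C, T, H, W)
_CONV_BY_RANK = {3: F.conv1d, 4: F.conv2d, 5: F.conv3d}


def per_sample_conv(x, weight, **conv_kwargs):
    """Hypothetical helper: apply weight[i] to x[i] via one grouped convolution.

    x:      [N, C_in, *spatial]
    weight: [N, C_out, C_in, *kernel]  (same filter shape for every sample, no bias)
    conv_kwargs (stride, padding, dilation) are passed through to F.convNd.
    """
    N, C_in = x.shape[:2]
    C_out = weight.shape[1]
    conv = _CONV_BY_RANK[x.dim()]
    # move the batch dim into the channels: [1, N * C_in, *spatial]
    x_folded = x.reshape(1, N * C_in, *x.shape[2:])
    # concatenate all filter sets along the output channels: [N * C_out, C_in, *kernel]
    w_folded = weight.reshape(N * C_out, C_in, *weight.shape[3:])
    # group g reads input channels [g*C_in, (g+1)*C_in) and writes output channels [g*C_out, (g+1)*C_out)
    out = conv(x_folded, w_folded, groups=N, **conv_kwargs)
    # split the output channels back into [N, C_out, *spatial_out]
    return out.reshape(N, C_out, *out.shape[2:])


# compare against the per-sample loop for conv2d and conv3d
x2, w2 = torch.randn(4, 3, 12, 12), torch.randn(4, 5, 3, 3, 3)
ref2 = torch.cat([F.conv2d(x2[i:i + 1], w2[i], padding=1) for i in range(4)])
ok_2d = torch.allclose(per_sample_conv(x2, w2, padding=1), ref2, atol=1e-4)

x3, w3 = torch.randn(2, 3, 6, 6, 6), torch.randn(2, 4, 3, 3, 3, 3)
ref3 = torch.cat([F.conv3d(x3[i:i + 1], w3[i], padding=1) for i in range(2)])
ok_3d = torch.allclose(per_sample_conv(x3, w3, padding=1), ref3, atol=1e-4)
print(ok_2d, ok_3d)
```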

If this approach works for your use case, I would recommend profiling both approaches to see whether my suggestion is faster for your workload.
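One way to do that comparison is torch.utils.benchmark.Timer, sketched below with the sizes from the example above on CPU tensors. Two things to keep in mind: Timer runs with a single thread by default (num_threads=1), and for CUDA tensors it synchronizes the device when timing, so moving x and w to the GPU should still give meaningful numbers. The timings themselves depend entirely on your hardware and shapes.

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark

# sizes taken from the example above; swap in your own workload
N, C, OC, H, W, K = 10, 3, 15, 24, 24, 5
x = torch.randn(N, C, H, W)
w = torch.randn(N, OC, C, K, K)


def loop_conv(x, w):
    # one F.conv2d call per sample
    return torch.cat([F.conv2d(x[i:i + 1], w[i], padding=K // 2) for i in range(x.size(0))])


def grouped_conv(x, w):
    # a single grouped F.conv2d call with the batch folded into the channels
    out = F.conv2d(x.reshape(1, N * C, H, W), w.reshape(N * OC, C, K, K), padding=K // 2, groups=N)
    return out.view(N, OC, H, W)


# sanity check before timing: both versions should agree
same = torch.allclose(loop_conv(x, w), grouped_conv(x, w), atol=1e-4)
print("outputs match:", same)

results = {}
for name, fn in [("loop", loop_conv), ("grouped", grouped_conv)]:
    timer = benchmark.Timer(stmt="fn(x, w)", globals={"fn": fn, "x": x, "w": w})
    results[name] = timer.timeit(50)
    print(f"{name}: {results[name].mean * 1e3:.3f} ms per call")
```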


Hi, thanks for your reply.
I think it’s the best option available to me.

This code was very helpful, because I want to do the same thing using conv3d.

However, when I ran the program below, I began to wonder whether this approach works with conv3d.

import torch
import torch.nn as nn
import torch.nn.functional as F

#Setup
N, T, C, H, W = 10, 24, 3, 24, 24
x = torch.stack([torch.full(size=(C,T,H,W), fill_value=float(n)) for n in range(N)])
for x_elem in x:
    print(x_elem.abs().max())

# Create filterset for each sample
OC = 15
weights = []
for _ in range(N):
    weight = nn.Parameter(torch.full(size=(OC,C,5,5,5), fill_value=float(_)))
    weights.append(weight)
    print(weight.abs().max())

# Apply manually
outputs = []
for idx in range(N):
    input = x[idx:idx+1]
    weight = weights[idx]
    output = F.conv3d(input, weight, stride=1, padding=2)
    outputs.append(output)


outputs = torch.stack(outputs)
outputs = outputs.squeeze(1) # remove fake batch dimension
# outputs = torch.cat(outputs, dim=0)
print(outputs.shape)

for output in outputs:
    print(output.abs().max())

# Use grouped approach
weights = torch.stack(weights)
weights = weights.view(-1, C, 5, 5, 5)
print(weights.shape)

# move batch dim into channels
x = x.view(1, -1, T, H, W)
print(x.shape)

# Apply grouped conv
outputs_grouped = F.conv3d(x, weights, stride=1, padding=2, groups=N)
outputs_grouped = outputs_grouped.view(10, 15, 24, 24, 24)
print(outputs_grouped.shape)
# Compare
print((outputs - outputs_grouped).abs().max())

for output_grouped in outputs_grouped:
    print(outputs_grouped.abs().max())

Output:

tensor(0.)
tensor(1.)
tensor(2.)
tensor(3.)
tensor(4.)
tensor(5.)
tensor(6.)
tensor(7.)
tensor(8.)
tensor(9.)
tensor(0., grad_fn=<MaxBackward1>)
tensor(1., grad_fn=<MaxBackward1>)
tensor(2., grad_fn=<MaxBackward1>)
tensor(3., grad_fn=<MaxBackward1>)
tensor(4., grad_fn=<MaxBackward1>)
tensor(5., grad_fn=<MaxBackward1>)
tensor(6., grad_fn=<MaxBackward1>)
tensor(7., grad_fn=<MaxBackward1>)
tensor(8., grad_fn=<MaxBackward1>)
tensor(9., grad_fn=<MaxBackward1>)
torch.Size([10, 15, 24, 24, 24])

for output in outputs:
    print(output.abs().max())
>tensor(0., grad_fn=<MaxBackward1>)
>tensor(375., grad_fn=<MaxBackward1>)
>tensor(1500., grad_fn=<MaxBackward1>)
>tensor(3375., grad_fn=<MaxBackward1>)
>tensor(6000., grad_fn=<MaxBackward1>)
>tensor(9375., grad_fn=<MaxBackward1>)
>tensor(13500., grad_fn=<MaxBackward1>)
>tensor(18375., grad_fn=<MaxBackward1>)
>tensor(24000., grad_fn=<MaxBackward1>)
>tensor(30375., grad_fn=<MaxBackward1>)

print(weights.shape)
>torch.Size([150, 3, 5, 5, 5])

print(x.shape)
>torch.Size([1, 30, 24, 24, 24])

print(outputs_grouped.shape)
>torch.Size([10, 15, 24, 24, 24])

print((outputs - outputs_grouped).abs().max())
>tensor(0., grad_fn=<MaxBackward1>)

for output_grouped in outputs_grouped:
    print(outputs_grouped.abs().max())
>tensor(30375., grad_fn=<MaxBackward1>)
>tensor(30375., grad_fn=<MaxBackward1>)
>tensor(30375., grad_fn=<MaxBackward1>)
>tensor(30375., grad_fn=<MaxBackward1>)
>tensor(30375., grad_fn=<MaxBackward1>)
>tensor(30375., grad_fn=<MaxBackward1>)
>tensor(30375., grad_fn=<MaxBackward1>)
>tensor(30375., grad_fn=<MaxBackward1>)
>tensor(30375., grad_fn=<MaxBackward1>)
>tensor(30375., grad_fn=<MaxBackward1>)

The max values printed for outputs increase with the torch.full fill value, as expected.
On the other hand, all of the max values printed for outputs_grouped are the same.

I think it might be caused by outputs_grouped = outputs_grouped.view(10, 15, 24, 24, 24), but I couldn’t find a solution.

I would be grateful if you could tell me about your experience.

Your code works fine and shows a zero absolute difference between the manual and grouped approaches.
The constant printed values are caused by using print(outputs_grouped.abs().max()) instead of print(output_grouped.abs().max()) in the last loop: the former prints the maximum of the entire outputs_grouped tensor in every iteration.
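The view itself is fine: in a grouped convolution, the output channels of group g occupy the contiguous block [g*OC, (g+1)*OC), so view(N, OC, T, H, W) puts each sample's outputs back in its own row. A scaled-down sketch of the constant-fill test (smaller sizes chosen arbitrarily) with the loop variable used correctly: each interior voxel sums C * K^3 = 375 products of n * n, so the per-sample maxima are 375 * n^2.

```python
import torch
import torch.nn.functional as F

N, C, OC, T, H, W, K = 3, 3, 2, 8, 8, 8, 5
# sample n and its filter set are filled with the constant n, as in the post above
x = torch.stack([torch.full((C, T, H, W), float(n)) for n in range(N)])
weights = torch.stack([torch.full((OC, C, K, K, K), float(n)) for n in range(N)])

outputs_grouped = F.conv3d(
    x.view(1, -1, T, H, W), weights.view(-1, C, K, K, K), padding=K // 2, groups=N
).view(N, OC, T, H, W)

# use the loop variable (output_grouped), not the whole tensor (outputs_grouped)
maxima = [output_grouped.abs().max().item() for output_grouped in outputs_grouped]
print(maxima)  # [0.0, 375.0, 1500.0]
```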

I made a very stupid mistake.

I can confirm that this code works correctly. Thank you very much.