How to apply different kernels to each example in a batch when using convolution?

Thanks for the update; I clearly misunderstood the use case.
I think if the kernel shapes are different, you would need to use a loop and concatenate the output afterwards, as the filters cannot be stored directly in a single tensor.
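
As a rough sketch of that loop approach, assuming each sample gets its own odd-sized square kernel (the sizes in `kernel_sizes` are made up for illustration) and that `padding=k // 2` is used so every output keeps the same spatial size and can be concatenated:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W = 4, 3, 24, 24
x = torch.randn(N, C, H, W)

# Hypothetical per-sample kernel sizes; odd values so padding=k // 2 preserves H and W
kernel_sizes = [1, 3, 5, 7]
weights = [torch.randn(15, C, k, k) for k in kernel_sizes]

outputs = []
for idx, weight in enumerate(weights):
    sample = x[idx:idx + 1]           # slice keeps the batch dimension
    pad = weight.shape[-1] // 2       # "same" padding for an odd kernel with stride 1
    outputs.append(F.conv2d(sample, weight, stride=1, padding=pad))

# Concatenation only works because all outputs share the shape [1, 15, H, W]
outputs = torch.cat(outputs, dim=0)
print(outputs.shape)  # torch.Size([4, 15, 24, 24])
```

If the kernels produced different output sizes (e.g. no padding), the results could not be concatenated directly and would have to stay in a list or be padded first.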

However, if the kernels all have the same shape, the grouped conv approach might still work.
Here is a small example using convolutions with in_channels=3, out_channels=15 for a batch size of 10:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Setup
N, C, H, W = 10, 3, 24, 24
x = torch.randn(N, C, H, W)

# Create filterset for each sample
weights = []
for _ in range(N):
    weight = nn.Parameter(torch.randn(15, 3, 5, 5))
    weights.append(weight)

# Apply manually, one sample at a time
outputs = []
for idx in range(N):
    sample = x[idx:idx+1]  # slice keeps the batch dimension; avoids shadowing the builtin `input`
    weight = weights[idx]
    output = F.conv2d(sample, weight, stride=1, padding=2)
    outputs.append(output)

outputs = torch.cat(outputs, dim=0)  # join the per-sample batches of size 1
print(outputs.shape)
> torch.Size([10, 15, 24, 24])

# Use grouped approach
weights = torch.stack(weights)
weights = weights.view(-1, 3, 5, 5)
print(weights.shape)
> torch.Size([150, 3, 5, 5])
# move batch dim into channels
x = x.view(1, -1, H, W)
print(x.shape)
> torch.Size([1, 30, 24, 24])
# Apply grouped conv
outputs_grouped = F.conv2d(x, weights, stride=1, padding=2, groups=N)
outputs_grouped = outputs_grouped.view(N, 15, 24, 24)

# Compare
print((outputs - outputs_grouped).abs().max())
> tensor(1.3351e-05, grad_fn=<MaxBackward1>)

If this approach works for you, I would recommend profiling both approaches to see whether my suggestion is actually faster for your workload.
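
A minimal timing sketch for that comparison could look like the following. The helper names `conv_loop`, `conv_grouped`, and `benchmark` are my own for illustration, the weights are assumed to be stacked into a single tensor of shape [N, out_channels, in_channels, kH, kW], and the numbers will depend heavily on your device and shapes:

```python
import time
import torch
import torch.nn.functional as F

def conv_loop(x, weights, padding):
    # x: [N, C, H, W]; weights: [N, out_ch, C, kH, kW]
    outs = [F.conv2d(x[i:i + 1], weights[i], padding=padding) for i in range(x.size(0))]
    return torch.cat(outs, dim=0)

def conv_grouped(x, weights, padding):
    N, C, H, W = x.shape
    out_ch = weights.size(1)
    w = weights.reshape(N * out_ch, C, *weights.shape[-2:])
    # Move the batch dim into channels so each group sees exactly one sample
    out = F.conv2d(x.reshape(1, N * C, H, W), w, padding=padding, groups=N)
    return out.reshape(N, out_ch, out.size(-2), out.size(-1))

def benchmark(fn, *args, iters=20, warmup=3):
    for _ in range(warmup):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for pending GPU work before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(10, 3, 24, 24, device=device)
weights = torch.randn(10, 15, 3, 5, 5, device=device)

with torch.no_grad():
    # Sanity check that both paths agree before timing them
    same = torch.allclose(conv_loop(x, weights, 2), conv_grouped(x, weights, 2), atol=1e-3)
    t_loop = benchmark(conv_loop, x, weights, 2)
    t_grouped = benchmark(conv_grouped, x, weights, 2)

print(f"match: {same}, loop: {t_loop * 1e3:.3f} ms, grouped: {t_grouped * 1e3:.3f} ms")
```

The synchronize calls matter on the GPU, since CUDA kernels are launched asynchronously and the timer would otherwise only measure the launch overhead.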
