Thanks for the update; I clearly misunderstood the use case.
I think if the kernel shapes are different, you would need to use a loop and concatenate the outputs afterwards, as the filters cannot be stored directly in a single tensor.
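Roughly something like this sketch (the per-sample kernel sizes are just made up for illustration, and I'm assuming odd kernel sizes so that padding=k//2 keeps the spatial size constant and the outputs can be concatenated):

import torch
import torch.nn as nn
import torch.nn.functional as F

N, C, H, W = 10, 3, 24, 24
x = torch.randn(N, C, H, W)

# hypothetical per-sample kernel sizes (odd, so padding=k//2 keeps H and W)
kernel_sizes = [3, 5, 7, 3, 5, 7, 3, 5, 7, 3]
weights = [nn.Parameter(torch.randn(15, C, k, k)) for k in kernel_sizes]

outputs = []
for idx in range(N):
    out = F.conv2d(x[idx:idx+1], weights[idx], padding=kernel_sizes[idx] // 2)
    outputs.append(out)
outputs = torch.cat(outputs)  # [10, 15, 24, 24]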
However, if the kernels have all the same shape, the grouped conv approach might still work.
Here is a small example using convolutions with in_channels=3, out_channels=15 for a batch size of 10:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Setup
N, C, H, W = 10, 3, 24, 24
x = torch.randn(N, C, H, W)
# Create filterset for each sample
weights = []
for _ in range(N):
weight = nn.Parameter(torch.randn(15, 3, 5, 5))
weights.append(weight)
# Apply manually
outputs = []
for idx in range(N):
x_i = x[idx:idx+1]
weight = weights[idx]
output = F.conv2d(x_i, weight, stride=1, padding=2)
outputs.append(output)
outputs = torch.stack(outputs)
outputs = outputs.squeeze(1) # remove fake batch dimension
print(outputs.shape)
> torch.Size([10, 15, 24, 24])
# Use grouped approach
weights = torch.stack(weights)
weights = weights.view(-1, 3, 5, 5)
print(weights.shape)
> torch.Size([150, 3, 5, 5])
# move batch dim into channels
x = x.view(1, -1, H, W)
print(x.shape)
> torch.Size([1, 30, 24, 24])
# Apply grouped conv
outputs_grouped = F.conv2d(x, weights, stride=1, padding=2, groups=N)
outputs_grouped = outputs_grouped.view(N, 15, 24, 24)
# Compare
print((outputs - outputs_grouped).abs().max())
> tensor(1.3351e-05, grad_fn=<MaxBackward1>)
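The remaining difference of ~1e-5 is just floating point noise, so a tolerance-based comparison should confirm both approaches give the same result:

# should print True, since the max. abs. difference above is ~1e-5
print(torch.allclose(outputs, outputs_grouped, atol=1e-4))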
If this approach works for your use case, I would recommend profiling both versions and checking which one is faster for your workload.
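A rough sketch for the profiling using torch.utils.benchmark (the wrapper functions are just placeholders that recreate the two approaches from above with same-shaped 5x5 kernels):

import torch
import torch.nn.functional as F
from torch.utils.benchmark import Timer

N, C, H, W = 10, 3, 24, 24
x = torch.randn(N, C, H, W)
weights = torch.randn(N, 15, C, 5, 5)

def conv_loop(x, weights):
    # loop approach: one conv per sample, concatenated afterwards
    outs = [F.conv2d(x[i:i+1], weights[i], stride=1, padding=2) for i in range(x.size(0))]
    return torch.cat(outs)

def conv_grouped(x, weights):
    # grouped approach: fold the batch dim into the channels and use groups=N
    n, c, h, w = x.shape
    out = F.conv2d(x.view(1, -1, h, w), weights.view(-1, c, 5, 5),
                   stride=1, padding=2, groups=n)
    return out.view(n, -1, h, w)

t_loop = Timer(stmt="conv_loop(x, weights)",
               globals={"conv_loop": conv_loop, "x": x, "weights": weights})
t_grouped = Timer(stmt="conv_grouped(x, weights)",
                  globals={"conv_grouped": conv_grouped, "x": x, "weights": weights})
print(t_loop.blocked_autorange())
print(t_grouped.blocked_autorange())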