Is there a way to adapt this approach so that, rather than performing a convolution with a different kernel per channel, it applies one kernel to all channels of a tensor, with the kernel changing for each tensor in the batch? In other words, each batch item is convolved with its own prespecified kernel.
e.g. something like this:
import torch
batch_size = 8
channels = 10
img_size = 30
kernel_size = 3
batch = torch.rand((batch_size,channels,img_size,img_size))
# Make a unique kernel for each batch member but the kernel is convolved
# with every channel
weights = torch.rand((batch_size,1,kernel_size,kernel_size)).repeat(1,channels,1,1)
print(weights.shape)
conv = torch.nn.Conv2d(channels,channels,kernel_size,padding=4,bias=False)
with torch.no_grad():
    # Note: this assigns a weight of shape (8, 10, 3, 3) to a Conv2d that
    # expects (10, 10, 3, 3), so the layer mixes all 10 input channels into
    # 8 output channels; it does not apply a per-batch-item kernel.
    conv.weight = torch.nn.Parameter(weights, requires_grad=False)
    output = conv(batch)
print(output.shape)
Edit: This has been solved using for loops or the groups argument of conv2d here.
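For reference, here is a minimal sketch of the groups-based idea the edit refers to (the variable names and the loop-based cross-check are my own, not taken from the linked answer): folding the batch dimension into the channel dimension lets a single F.conv2d call apply a different kernel to each batch member.

import torch
import torch.nn.functional as F

batch_size, channels, img_size, kernel_size = 8, 10, 30, 3
batch = torch.rand((batch_size, channels, img_size, img_size))
kernels = torch.rand((batch_size, 1, kernel_size, kernel_size))

# Fold the batch dimension into the channel dimension: (1, B*C, H, W)
x = batch.reshape(1, batch_size * channels, img_size, img_size)
# Repeat each batch member's kernel once per channel: (B*C, 1, k, k)
w = kernels.repeat_interleave(channels, dim=0)
# With groups=B*C, each channel is convolved with exactly one kernel
out = F.conv2d(x, w, padding=kernel_size // 2, groups=batch_size * channels)
out = out.reshape(batch_size, channels, img_size, img_size)

# For-loop equivalent (one grouped conv per batch member), as a cross-check
loop_out = torch.cat([
    F.conv2d(batch[i : i + 1],
             kernels[i : i + 1].repeat(channels, 1, 1, 1),
             padding=kernel_size // 2, groups=channels)
    for i in range(batch_size)
])
assert torch.allclose(out, loop_out)
print(out.shape)  # torch.Size([8, 10, 30, 30])

The grouped call avoids the Python-level loop, which matters as batch_size grows; padding=kernel_size // 2 is used here so the spatial size is preserved (the padding=4 in the question's snippet would instead enlarge the output).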