Best way to convolve on different channels with a single kernel?

Say I have a 2D signal which is basically 4 channels of 1D signals, where each 1D signal has length 100, so each sample is of shape 4x100 (channels x length). Now I have a single torch.nn.Conv1d() kernel and I want to apply the same kernel to each of the channels individually. What is the most efficient way to do this?

The method I have come up with is to use a list, but I feel there should be a more elegant way to do the same. My code is:

out_list = []
for i in range(num_channels):
    # inp is of shape batch_size x channels (4) x vector_len (100)
    out = conv(inp[:, [i], :])  # slicing with [i] keeps a channel dim of 1
    out_list.append(out)
o1 = torch.cat(out_list, 0)                 # (num_channels * batch_size) x 1 x out_len
o1 = o1.view(num_channels, batch_size, -1)  # num_channels x batch_size x out_len
o1 = o1.permute(1, 0, 2).contiguous()       # batch_size x num_channels x out_len
return o1

I would be glad to know if there is a better implementation (or if there is a bug in the presented code). Thank you!

You could repeat the single kernel and use the functional API:

import torch
import torch.nn.functional as F
from torch.autograd import Variable

in_channels = 4  # four 1D channels, as in the question
x = Variable(torch.randn(1, in_channels, 100))

# Gaussian kernel of size 5
kernel = Variable(torch.FloatTensor([[[0.06136, 0.24477, 0.38774, 0.24477, 0.06136]]]))
# F.conv1d expects a weight of shape (out_channels, in_channels, kernel_size)
weights = kernel.repeat(1, in_channels, 1)
output = F.conv1d(x, weights)  # output: 1 x 1 x 96

Would this work for your use case?

Will this also work if I want to learn the kernel as well? That is, during backward() and the optimizer's weight update, will this ensure that the kernel's weights get updated by the average of the gradients over all the channels?

You could initialize it with the same weights for each input channel, but vanilla weight updates will change the weights without any constraint forcing them to keep the same values.
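To illustrate that caveat, here is a rough sketch (the sizes and the Parameter setup are made up for the example, not part of the snippet above): if you make the repeated kernel a learnable parameter, each channel slice receives its own gradient, so the slices drift apart after the first optimizer step.

import torch
import torch.nn.functional as F

in_channels = 4
x = torch.randn(8, in_channels, 100)

# start every input-channel slice from the same 1x5 kernel
kernel = torch.FloatTensor([[[0.06136, 0.24477, 0.38774, 0.24477, 0.06136]]])
weights = torch.nn.Parameter(kernel.repeat(1, in_channels, 1))

out = F.conv1d(x, weights)  # 8 x 1 x 96
out.sum().backward()

# the gradient generally differs between channel slices,
# so after an update the slices are no longer identical
print(weights.grad[0, 0])
print(weights.grad[0, 1])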

If you need to backpropagate, the better method would be to change the view of the input so that the channels are moved into the batch dimension, and you only use a single kernel.

This should work:

batch_size, in_channels = 8, 4  # example sizes
x = Variable(torch.randn(batch_size, in_channels, 100))

# Gaussian kernel of size 5
kernel = Variable(torch.FloatTensor([[[0.06136, 0.24477, 0.38774, 0.24477, 0.06136]]]))
# fold the channels into the batch dimension, convolve once with the single kernel,
# then restore the channel dimension (output length is 100 - 5 + 1 = 96)
output = F.conv1d(x.view(-1, 1, 100), kernel).view(batch_size, in_channels, 96)
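If the kernel itself should be learnable, you could wrap this trick in a small module. This is just a sketch with made-up names (SharedKernelConv1d, kernel_size), not tested code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SharedKernelConv1d(nn.Module):
    """Applies one learnable 1D kernel to every input channel independently."""
    def __init__(self, kernel_size):
        super(SharedKernelConv1d, self).__init__()
        # a single kernel of shape (out_channels=1, in_channels=1, kernel_size)
        self.weight = nn.Parameter(torch.randn(1, 1, kernel_size))

    def forward(self, x):
        batch_size, in_channels, length = x.size()
        # fold the channels into the batch dimension, convolve with the single
        # kernel, then restore the channel dimension
        out = F.conv1d(x.contiguous().view(-1, 1, length), self.weight)
        return out.view(batch_size, in_channels, -1)

conv = SharedKernelConv1d(kernel_size=5)
x = torch.randn(8, 4, 100)
print(conv(x).shape)  # torch.Size([8, 4, 96])

Since the same parameter is reused for every channel, its gradient accumulates the contributions from all channels, so a single optimizer step keeps the kernel shared across them.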

Sweet. This should work.