Applying a conv2d filter to all channels separately: is my solution efficient?


#1

Hi,

For a given input of size (batch, channels, width, height), I would like to apply a stride-2 convolution with a single fixed 2D filter to each channel of each sample in the batch, resulting in an output of size (batch, channels, width/2, height/2).

Using the groups parameter of nn.functional.conv2d, I came up with the solution below.

I would like to apply the filter

fil = torch.tensor([
    [0.5,  0.5],
    [-0.5, -0.5]])

to my input

X = torch.rand(32, 2048, 128, 128)

To this end, I add two dummy dimensions (out_channels and in_channels/groups) to my filter and expand the 0th dimension of my filter tensor to be equal to the number of channels of my input (in this case 2048). I’m keeping the 1st dimension unchanged since in_channels/groups will be equal to 1 by using groups=in_channels in nn.functional.conv2d.

fil_tensor = fil[None, None, :, :].expand(X.size(1), -1, -1, -1)

This works:

res = torch.nn.functional.conv2d(
    X, fil_tensor, stride=2, groups=X.size(1))

but I’m worried about the step where I expanded my filter, basically creating 2048 copies of redundant information. Is there a better way to do this?
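On the worry about redundant copies: as far as I know, expand returns a view with stride 0 along the expanded dimension rather than allocating new memory, so fil_tensor itself still holds only the original four values. Whether conv2d materializes a contiguous copy internally is not something this checks. A small sketch to inspect the view property:

```python
import torch

fil = torch.tensor([[0.5, 0.5],
                    [-0.5, -0.5]])
channels = 2048

# Same expansion as in the question: (2, 2) -> (1, 1, 2, 2) -> (2048, 1, 2, 2)
fil_tensor = fil[None, None, :, :].expand(channels, -1, -1, -1)

shape = tuple(fil_tensor.shape)
strides = fil_tensor.stride()
# If expand is a view, both tensors start at the same memory address
shares_storage = fil_tensor.data_ptr() == fil.data_ptr()

print(shape, strides, shares_storage)
```

A stride of 0 in the first dimension means every "copy" of the filter reads the same memory.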

Thanks!


#2

I think it would be faster to reshape your input, so that your channels are stacked in the batch dimension.
[batch_size, channels, h, w] would become [batch_size * channels, 1, h, w].
Then you could use a conv layer with in_channels=1 and out_channels=1 and reshape the output again.

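import torch
import torch.nn as nn

batch_size = 10
channels = 3
h, w = 24, 24
x = torch.randn(batch_size, channels, h, w)

# kernel_size=4, stride=2, padding=1 halves the spatial size
conv = nn.Conv2d(1, 1, 4, 2, 1)
# Fold the channels into the batch dimension, convolve, then unfold again
output = conv(x.view(-1, 1, h, w)).view(batch_size, channels, h//2, w//2)
print(output.shape)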
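
The snippet above uses a randomly initialized 4x4 layer to illustrate the reshape. Here is a minimal sketch of the same idea with the fixed filter from the question, compared against the grouped version. The helper name apply_filter_per_channel is hypothetical and the tensor sizes are arbitrary:

```python
import torch
import torch.nn.functional as F

def apply_filter_per_channel(x, fil, stride=2):
    """Apply one 2D filter to every channel by folding channels into the batch."""
    b, c, h, w = x.shape
    weight = fil[None, None, :, :]  # (out_channels=1, in_channels=1, kh, kw)
    out = F.conv2d(x.reshape(b * c, 1, h, w), weight, stride=stride)
    return out.reshape(b, c, out.shape[-2], out.shape[-1])

fil = torch.tensor([[0.5, 0.5],
                    [-0.5, -0.5]])
x = torch.randn(4, 6, 16, 16)

out_view = apply_filter_per_channel(x, fil)

# Reference: the grouped convolution from the original question
fil_tensor = fil[None, None, :, :].expand(x.size(1), -1, -1, -1)
out_grouped = F.conv2d(x, fil_tensor, stride=2, groups=x.size(1))

max_diff = (out_view - out_grouped).abs().max().item()
print(out_view.shape, max_diff)
```

The two results should agree up to float precision. Which one is faster likely depends on the backend and tensor sizes, so it is worth timing both on the real input.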

#3

This is exactly what I was looking for! Appreciate it!


(Alban D) #4

Just for fun,
I think you can also do it with:

  • Average pooling with kernel [1, 2] and stride [1, 2].
  • Flip the sign of every other row
  • Sum every pair of rows.

I don’t think that’s going to be more efficient than @ptrblck's solution though…


#5

Quite an interesting approach. Haven’t thought about it and wanted to try it out.
Not “optimized” code, but the error seems to show the results are equal (up to float precision):

import torch
import torch.nn as nn

batch_size = 10
channels = 3
h, w = 24, 24
x = torch.randn(batch_size, channels, h, w)

# View approach: fold channels into the batch and use the fixed filter
conv = nn.Conv2d(1, 1, 2, 2, bias=False)
with torch.no_grad():
    conv.weight = nn.Parameter(torch.tensor([[[[0.5, 0.5],
                                               [-0.5, -0.5]]]]))
output = conv(x.view(-1, 1, h, w)).view(batch_size, channels, h//2, w//2)

# Pool approach: average horizontal pairs, negate odd rows, sum row pairs
pool = nn.AvgPool2d((1, 2), (1, 2))
output_ = pool(x)

output_[:, :, 1::2, :] = output_[:, :, 1::2, :] * -1
output_ = torch.cat([output_[:, :, a:a+1, :] + output_[:, :, a+1:a+2, :] for a in range(0, h, 2)], dim=2)

# Compare the signed values; comparing absolute values would hide sign errors
print((output - output_).abs().max())

(Alban D) #6

Well, advanced indexing is not my thing, but it works well indeed (it might even be more efficient than the conv :open_mouth:):

# Pool approach without advanced indexing
pool = nn.AvgPool2d((1, 2), (1, 2))
output_2 = pool(x)

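# Split the rows into pairs: (batch, channels, h//2, 2, w//2)
output_2 = output_2.view(batch_size, channels, h//2, 2, w//2)
# Negate the second row of each pair in place
output_2.select(3, 1).mul_(-1)
# Sum each pair; sum returns a tensor, so no indexing is needed
output_2 = output_2.sum(3)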
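
Here is a variant of the pooling idea that avoids the in-place sign flip by subtracting the second row of each pair from the first. This is only a sketch: the helper name pool_filter is hypothetical, it assumes even height and width, and it is specific to the [[0.5, 0.5], [-0.5, -0.5]] filter from the question.

```python
import torch
import torch.nn.functional as F

def pool_filter(x):
    """Equivalent of the fixed 2x2 filter with stride 2, built from average pooling."""
    b, c, h, w = x.shape
    # Average each horizontal pair: (b, c, h, w) -> (b, c, h, w // 2)
    pooled = F.avg_pool2d(x, kernel_size=(1, 2), stride=(1, 2))
    # Group rows into pairs: (b, c, h // 2, 2, w // 2)
    rows = pooled.reshape(b, c, h // 2, 2, w // 2)
    # Top row minus bottom row matches the +0.5 / -0.5 filter rows
    return rows[:, :, :, 0] - rows[:, :, :, 1]

x = torch.randn(4, 6, 16, 16)
out_pool = pool_filter(x)

# Reference: the reshaped single-channel convolution with the fixed filter
fil = torch.tensor([[0.5, 0.5],
                    [-0.5, -0.5]])
b, c, h, w = x.shape
out_conv = F.conv2d(x.reshape(b * c, 1, h, w), fil[None, None], stride=2)
out_conv = out_conv.reshape(b, c, h // 2, w // 2)

print(out_pool.shape, (out_pool - out_conv).abs().max().item())
```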

#7

Awesome! Thanks for this approach. :slight_smile: