# Applying conv2d filter to all channels separately: is my solution efficient?

Hi,

For a given input of size (batch, channels, width, height) I would like to apply a 2-strided convolution with a single fixed 2D-filter to each channel of each batch, resulting in an output of size (batch, channels, width/2, height/2).

Using the `groups` parameter of `nn.functional.conv2d`, I came up with the solution below.

I would like to apply the filter

``````
fil = torch.tensor([
    [0.5,  0.5],
    [-0.5, -0.5]])
``````

to my input

`X = torch.rand(32, 2048, 128, 128)`.

To this end, I add two dummy dimensions (out_channels and in_channels/groups) to my filter and expand the 0th dimension of my filter tensor to be equal to the number of channels of my input (in this case 2048). I'm keeping the 1st dimension unchanged, since in_channels/groups will be equal to 1 when using groups=in_channels in nn.functional.conv2d.

`fil_tensor = fil[None, None, :, :].expand(X.size(1), -1, -1, -1)`

This works:

``````
res = torch.nn.functional.conv2d(
    X, fil_tensor, stride=2, groups=X.size(1))
``````

but I'm worried about the step where I expanded my filter, basically creating 2048 copies of redundant information. Is there a better way to do this?
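One thing that may ease this worry: `expand` returns a view rather than allocating copies, which can be checked from the strides and the data pointer. Below is a minimal sketch with small, made-up sizes (the shapes are illustrative assumptions, not the sizes from the question). It only inspects the expanded tensor itself and says nothing about what `conv2d` does with the weight internally.

```python
import torch
import torch.nn.functional as F

# Small illustrative sizes (assumed for this sketch; the question uses 32 x 2048 x 128 x 128)
X = torch.rand(4, 16, 8, 8)
fil = torch.tensor([[0.5, 0.5],
                    [-0.5, -0.5]])

# Shape (channels, 1, 2, 2): one single-channel filter per group
fil_tensor = fil[None, None, :, :].expand(X.size(1), -1, -1, -1)

# expand() returns a view: stride 0 along the expanded dim, same underlying data as fil
print(fil_tensor.shape)
print(fil_tensor.stride())
print(fil_tensor.data_ptr() == fil.data_ptr())

# Grouped, 2-strided convolution: each channel is filtered on its own
res = F.conv2d(X, fil_tensor, stride=2, groups=X.size(1))
print(res.shape)
```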

Thanks!


I think it would be faster to reshape your input, so that your channels are stacked in the batch dimension.
`[batch_size, channels, h, w]` would become `[batch_size * channels, 1, h, w]`.
Then you could use a conv layer with `in_channels=1` and `out_channels=1` and reshape the output again.

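``````
import torch
import torch.nn as nn

batch_size = 10
channels = 3
h, w = 24, 24
x = torch.randn(batch_size, channels, h, w)

# Fold the channels into the batch dimension so a single 1-in/1-out conv handles every channel
conv = nn.Conv2d(1, 1, kernel_size=4, stride=2, padding=1)
output = conv(x.view(-1, 1, h, w)).view(batch_size, channels, h//2, w//2)
print(output.shape)  # torch.Size([10, 3, 12, 12])
``````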
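For the fixed filter from the question, the same reshape trick can also be sketched with the functional API, so no learnable layer is needed. The sizes are illustrative assumptions, and `reshape` is used instead of `view` so that non-contiguous inputs would also work. The grouped version is included only to compare results.

```python
import torch
import torch.nn.functional as F

batch_size, channels, h, w = 4, 16, 8, 8  # illustrative sizes
x = torch.randn(batch_size, channels, h, w)
fil = torch.tensor([[0.5, 0.5],
                    [-0.5, -0.5]])

# Reshape approach: every channel becomes its own single-channel sample
weight = fil[None, None]  # shape (1, 1, 2, 2)
out_reshape = F.conv2d(x.reshape(-1, 1, h, w), weight, stride=2)
out_reshape = out_reshape.reshape(batch_size, channels, h // 2, w // 2)

# Grouped approach from the question, for comparison
out_grouped = F.conv2d(x, weight.expand(channels, -1, -1, -1),
                       stride=2, groups=channels)

# Both should agree up to float precision
print(torch.allclose(out_reshape, out_grouped, atol=1e-5))
```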

This is exactly what I was looking for! Appreciate it!

Just for fun,
I think you can also do it with:

• Average pooling with kernel [1, 2] and stride [1, 2].
• Flip the sign of every other row
• Sum every pair of rows.

I don't think that's going to be more efficient than @ptrblck's solution though…
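To see why the three steps above reproduce the filter, write $x$ for one channel, $p$ for the result of the average pooling, and $y$ for the output of the 2-strided convolution (these symbols are introduced here only for illustration):

```latex
\begin{aligned}
p_{r,j} &= \tfrac{1}{2}\left(x_{r,2j} + x_{r,2j+1}\right) \\
y_{i,j} &= \tfrac{1}{2}x_{2i,2j} + \tfrac{1}{2}x_{2i,2j+1}
         - \tfrac{1}{2}x_{2i+1,2j} - \tfrac{1}{2}x_{2i+1,2j+1} \\
        &= p_{2i,j} - p_{2i+1,j}
\end{aligned}
```

So negating the odd rows of $p$ and summing each pair of rows gives $y$.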


Quite an interesting approach. I hadn't thought about it and wanted to try it out.
Not "optimized" code, but the error seems to show the results are equal (up to float precision):

``````
import torch
import torch.nn as nn

batch_size = 10
channels = 3
h, w = 24, 24
x = torch.randn(batch_size, channels, h, w)

# View approach: fold channels into the batch dimension and apply the fixed filter
conv = nn.Conv2d(1, 1, 2, 2, bias=False)
conv.weight = nn.Parameter(torch.tensor([[[[0.5, 0.5],
                                           [-0.5, -0.5]]]]))
output = conv(x.view(-1, 1, h, w)).view(batch_size, channels, h//2, w//2)

# Pool approach: average horizontal pairs, negate odd rows, sum each pair of rows
pool = nn.AvgPool2d((1, 2), (1, 2))
output_ = pool(x)

output_[:, :, 1::2, :] = output_[:, :, 1::2, :] * -1
output_ = torch.cat([output_[:, :, a:a+1, :] + output_[:, :, a+1:a+2, :] for a in range(0, h, 2)], dim=2)

# Compare the signed difference; comparing absolute values would hide sign errors
print((output - output_).abs().max())
``````

Well, advanced indexing is not my thing, but it works well indeed (and might even be more efficient than the conv):

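``````
# Pool approach without advanced indexing
# (reuses x, batch_size, channels, h, w from the previous snippet)
pool = nn.AvgPool2d((1, 2), (1, 2))
output_2 = pool(x)

# Split the rows into pairs, negate the second row of each pair in place, then sum each pair
output_2 = output_2.view(batch_size, channels, h//2, 2, w//2)
output_2.select(3, 1).mul_(-1)
output_2 = output_2.sum(3)  # shape: (batch_size, channels, h//2, w//2)
``````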
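Whether the pooling variant is actually faster probably depends on the device and the tensor sizes. A rough way to check is to time both on the same input. The following is only a sketch using the standard-library `timeit` on CPU, with illustrative sizes; the helper names `conv_approach` and `pool_approach` are made up for this example. On a GPU the timings would additionally need `torch.cuda.synchronize()` calls to be meaningful.

```python
import timeit

import torch
import torch.nn.functional as F

batch_size, channels, h, w = 8, 64, 64, 64  # illustrative sizes
x = torch.randn(batch_size, channels, h, w)
weight = torch.tensor([[[[0.5, 0.5],
                         [-0.5, -0.5]]]])  # shape (1, 1, 2, 2)


def conv_approach(x):
    # Channels folded into the batch dimension, then one 2-strided conv
    b, c, h, w = x.shape
    out = F.conv2d(x.reshape(-1, 1, h, w), weight, stride=2)
    return out.reshape(b, c, h // 2, w // 2)


def pool_approach(x):
    # Average horizontal pairs, then subtract the second row of each pair from the first
    b, c, h, w = x.shape
    pooled = F.avg_pool2d(x, kernel_size=(1, 2), stride=(1, 2))
    pooled = pooled.view(b, c, h // 2, 2, w // 2)
    return pooled[:, :, :, 0] - pooled[:, :, :, 1]


# Sanity check before timing: both variants should agree up to float precision
print(torch.allclose(conv_approach(x), pool_approach(x), atol=1e-5))

for fn in (conv_approach, pool_approach):
    seconds = timeit.timeit(lambda: fn(x), number=20)
    print(f"{fn.__name__}: {seconds / 20 * 1e3:.3f} ms per call")
```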

Awesome! Thanks for this approach.
