Separable Convolutions in PyTorch

Hi,

I’m still new to PyTorch, and I’m trying to implement MobileNets (Howard et al.) in PyTorch. The paper is built around depthwise separable convolutions. TensorFlow’s tf.slim contains a separable convolution operation, and I wanted to know whether a similar operation is available in PyTorch as well.

If not, do you plan to support it in the future?

Thanks

One option is to build separable convolutions from 1xN and Nx1 convolutions. This may not be as efficient as a single fused kernel call, but it should be acceptable.
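A minimal sketch of the 1xN / Nx1 idea for a single channel, assuming a 3x3 target kernel (the size `k` and the variable names are illustrative). A 1xk convolution followed by a kx1 convolution produces the same output as one kxk convolution whose kernel is the outer product of the two filters, using 2k weights instead of k². Note this is *spatial* separability, which differs from the depthwise-then-pointwise factorization used in MobileNets.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

k = 3
# 1xk (horizontal) filter followed by a kx1 (vertical) filter; padding keeps H and W unchanged
row = nn.Conv2d(1, 1, kernel_size=(1, k), padding=(0, k // 2), bias=False)
col = nn.Conv2d(1, 1, kernel_size=(k, 1), padding=(k // 2, 0), bias=False)
separable = nn.Sequential(row, col)

x = torch.randn(1, 1, 8, 8)  # NCHW: one single-channel 8x8 image
y_sep = separable(x)

# The equivalent full kxk kernel is the outer product of the column and row filters
full_kernel = col.weight.view(k, 1) @ row.weight.view(1, k)  # shape (k, k)
y_full = F.conv2d(x, full_kernel.view(1, 1, k, k), padding=k // 2)

print(torch.allclose(y_sep, y_full, atol=1e-5))  # True, up to float rounding
```

This equivalence only holds for kernels of rank 1, so the factorized form is less expressive than a general kxk convolution.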

Hi there

I gather you mean something along these lines:

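import torch
import torch.nn as nn

N = 16  # number of intermediate channels (illustrative)
x = torch.randn(1, 1, 32, 32)  # NCHW: batch, channels, height, width
model = nn.Sequential(
    nn.Conv2d(1, N, kernel_size=3, stride=2),  # 3x3 conv producing N feature maps
    nn.Conv2d(N, 1, kernel_size=1, stride=1),  # 1x1 conv combining them back to 1 map
)
res = model(x)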

Is it possible to create our own custom layer to implement this?

You have to apply a different 2D filter to each of the M input channels I_m; say this gives M filtered maps F_m. Each of the N output channels O_n is then a different linear combination of these F_m.
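In symbols, writing K_m for the k×k filter applied to input channel I_m, w_{n,m} for the mixing weights (both names chosen here for illustration), and ⋆ for 2D cross-correlation:

$$F_m = K_m \star I_m, \qquad m = 1, \dots, M$$
$$O_n = \sum_{m=1}^{M} w_{n,m}\, F_m, \qquad n = 1, \dots, N$$

Ignoring biases, this uses $M k^2 + N M$ weights, compared with $N M k^2$ for a standard convolution from M to N channels.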

I don’t see how to avoid a loop over the input channels, applying nn.Conv2d(1, 1) M times. The N linear combinations can then be computed with a single matrix multiplication.
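A sketch of the loop approach described above, assuming square k×k filters and no biases. `LoopSeparableConv` is a hypothetical name, and the channel mixing is written with `torch.einsum`, which here is just a matrix product over the channel axis.

```python
import torch
import torch.nn as nn


class LoopSeparableConv(nn.Module):
    """Hypothetical layer: one Conv2d(1, 1) per input channel, then a channel-mixing matmul."""

    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super().__init__()
        # One single-channel spatial filter per input channel
        self.depthwise = nn.ModuleList(
            [nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False)
             for _ in range(in_channels)]
        )
        # Mixing matrix of shape (N, M): row n holds the weights of output channel n
        self.mix = nn.Parameter(torch.randn(out_channels, in_channels) / in_channels ** 0.5)

    def forward(self, x):
        # Loop over the M input channels, filtering each one separately -> (B, M, H', W')
        filtered = torch.cat(
            [conv(x[:, m:m + 1]) for m, conv in enumerate(self.depthwise)], dim=1
        )
        # All N linear combinations at once: a matrix product over the channel axis
        return torch.einsum("nm,bmhw->bnhw", self.mix, filtered)


layer = LoopSeparableConv(8, 16, kernel_size=3, padding=1)
x = torch.randn(2, 8, 10, 10)
out = layer(x)
print(out.shape)  # torch.Size([2, 16, 10, 10])
```

The Python loop launches M small convolutions per forward pass, so it is likely to be slow for large M; it mainly serves to make the computation explicit.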

Isn’t this just a grouped convolution? If it’s anything like the depthwise separable convolution in this paper (https://arxiv.org/pdf/1610.02357.pdf), then I think it is.

Example for MobileNet:

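import torch.nn as nn

N = 32  # number of channels (illustrative)
conv = nn.Sequential(
    nn.Conv2d(1, N, kernel_size=(3, 3), stride=(2, 2)),            # standard 3x3 conv
    nn.Conv2d(N, N, kernel_size=(3, 3), stride=(2, 2), groups=N),  # depthwise: one 3x3 filter per channel
    nn.Conv2d(N, 1, kernel_size=(1, 1), stride=(1, 1)),            # pointwise: 1x1 conv mixes channels
)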
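With groups=N, each channel gets its own filter, so the grouped convolution performs the per-channel filtering step, and the 1x1 convolution performs the linear combinations. A quick parameter count (N = 32 and bias=False are choices made here for illustration) shows the saving:

```python
import torch.nn as nn


def num_params(module):
    # Total number of learnable scalars in the module
    return sum(p.numel() for p in module.parameters())


N = 32
full = nn.Conv2d(N, N, kernel_size=3, bias=False)                 # standard 3x3 conv
depthwise = nn.Conv2d(N, N, kernel_size=3, groups=N, bias=False)  # one 3x3 filter per channel
pointwise = nn.Conv2d(N, N, kernel_size=1, bias=False)            # 1x1 channel mixing

print(num_params(full))                               # 9216 (N * N * 9)
print(num_params(depthwise))                          # 288  (N * 9)
print(num_params(depthwise) + num_params(pointwise))  # 1312 (N * 9 + N * N)
```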

I’m not familiar with the concept of groups yet; however, in this paper they use batchnorm after both the depthwise and the pointwise convolution. Alexis’ solution seems more plausible for now, but I’ll post what I find as soon as possible.
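Grouped convolutions and batchnorm combine naturally in a small custom nn.Module, which also addresses the earlier question about defining your own layer. The sketch below assumes the MobileNet v1 layout (3x3 depthwise conv, BN, ReLU, then 1x1 pointwise conv, BN, ReLU); the class name `DepthwiseSeparableConv`, padding=1, and bias=False (since BN follows each conv) are choices made here, not part of any PyTorch API.

```python
import torch
import torch.nn as nn


class DepthwiseSeparableConv(nn.Module):
    """Hypothetical MobileNet-style block: dw 3x3 -> BN -> ReLU -> pw 1x1 -> BN -> ReLU."""

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # groups=in_channels gives each input channel its own 3x3 filter
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride,
                                   padding=1, groups=in_channels, bias=False)
        self.bn1 = nn.BatchNorm2d(in_channels)
        # 1x1 convolution forms linear combinations of the filtered channels
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.bn1(self.depthwise(x)))
        return self.relu(self.bn2(self.pointwise(x)))


block = DepthwiseSeparableConv(32, 64, stride=2)
y = block(torch.randn(4, 32, 56, 56))
print(y.shape)  # torch.Size([4, 64, 28, 28])
```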


I tried groups=num_output, and the result is even slower than the full convolution. Can someone implement separable convolutions properly?
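A rough way to check speed claims like this on your own machine is a micro-benchmark with `torch.utils.benchmark`. The sizes below are arbitrary and kept small; timings depend heavily on hardware, PyTorch version, and backend, so no numbers are given here. Note that `benchmark.Timer` runs single-threaded by default.

```python
import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark

N = 32
x = torch.randn(2, N, 32, 32)
layers = {
    "full 3x3": nn.Conv2d(N, N, 3, padding=1, bias=False),
    "depthwise 3x3 (groups=N)": nn.Conv2d(N, N, 3, padding=1, groups=N, bias=False),
}

results = {}
for name, layer in layers.items():
    layer.requires_grad_(False)  # time the forward pass only, without building an autograd graph
    timer = benchmark.Timer(stmt="layer(x)", globals={"layer": layer, "x": x})
    results[name] = timer.timeit(20).mean  # mean seconds per call
    print(f"{name}: {results[name] * 1e3:.3f} ms")
```

Fewer FLOPs do not guarantee lower wall-clock time: a grouped convolution is only faster if the backend has an efficient kernel for it.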


I filed a feature request here with more details.