Parameters almost same in Depthwise Separable Convolution (when compared to a standard CNN)?

Hi,

I am new to PyTorch and really like the Pythonic approach it offers. Currently I’m trying to compare the performance of a simple depthwise separable convolutional neural net against a standard convolutional neural net. Both models share the same simple architecture:

The following block is repeated 3 times:
Convolution -> Batch Normalization -> ReLU -> Pool
This is followed by flattening and reducing to 3 outputs in two fully connected steps.

The code is as follows:

Standard Conv. Net architecture:

import numpy as np
import torch.nn as nn

in_features = 304  # in_features for the first linear layer after flattening

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=2, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=4)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)        
        
        self.conv2 = nn.Conv2d(4, 8, kernel_size=2, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=8)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)        

        self.conv3 = nn.Conv2d(8, 16, kernel_size=2, padding=1)
        self.bn3 = nn.BatchNorm2d(num_features=16)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(2)        

        self.fc1   = nn.Linear(in_features, 36)
        self.relu4 = nn.ReLU()
        self.fc2   = nn.Linear(36, 3)
        
    def forward(self, x):
        out = self.pool1(self.relu1(self.bn1(self.conv1(x))))
        out = self.pool2(self.relu2(self.bn2(self.conv2(out))))
        out = self.pool3(self.relu3(self.bn3(self.conv3(out))))
        
        out = out.view(out.size(0), -1)
        out = self.relu4(self.fc1(out))
        out = self.fc2(out)

        return out

standardCNN = Net()  # defining an instance of our network

Depthwise Separable Conv. Net architecture:

class depthwise_separable_conv(nn.Module):
    def __init__(self):
        super(depthwise_separable_conv, self).__init__()
        self.depthwise1 = nn.Conv2d(1, 1, kernel_size=2, padding=1, groups=1)
        self.pointwise1 = nn.Conv2d(1, 4, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(num_features=4)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)        
        
        self.depthwise2 = nn.Conv2d(4, 4, kernel_size=2, padding=1, groups=4)
        self.pointwise2 = nn.Conv2d(4, 8, kernel_size=1)
        self.bn2 = nn.BatchNorm2d(num_features=8)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        
        self.depthwise3 = nn.Conv2d(8, 8, kernel_size=2, padding=1, groups=8)
        self.pointwise3 = nn.Conv2d(8, 16, kernel_size=1)
        self.bn3 = nn.BatchNorm2d(num_features=16)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(2)        
        
        self.dense1 = nn.Linear(304, 36)
        self.relu4 = nn.ReLU()
        self.dense2 = nn.Linear(36,3)
        
    def forward(self, x):
        out = self.pool1(self.relu1(self.bn1(self.pointwise1(self.depthwise1(x)))))
        out = self.pool2(self.relu2(self.bn2(self.pointwise2(self.depthwise2(out)))))
        out = self.pool3(self.relu3(self.bn3(self.pointwise3(self.depthwise3(out)))))
        
        out = out.view(out.size(0), -1)
        out = self.relu4(self.dense1(out))
        out = self.dense2(out)
        
        return out

dsCNN_model = depthwise_separable_conv()

Now, here’s the important part. Since standard convolutions require far more parameters than depthwise separable ones, I was expecting a noticeably smaller parameter count for my dsCNN_model. But to my surprise, I got these results.

Parameters:

total_params = 0
for parameter in standardCNN.parameters():
    if parameter.requires_grad:
        total_params += np.prod(parameter.size())
print(total_params)
>>11831

total_params = 0
for parameter in dsCNN_model.parameters():
    if parameter.requires_grad:
        total_params += np.prod(parameter.size())
print(total_params)
>>11404

Both models have almost the same number of parameters, with a reduction of only about 400 (from ~11800 to ~11400).

I wanted to ask: is this normal behavior for a depthwise separable model, so that smaller models simply don’t show much of a difference? Or is something wrong with my approach or with the model architecture? I’m trying to get a model with as few parameters as possible so that inference is feasible on hardware with limited computational resources. Also, what other approaches does PyTorch offer for this purpose?

Any help would be really appreciated… Thanks!

Let kernel_size be K, in_channels M and out_channels N, and assume there is no bias.
A naive convolution has KKMN parameters.
A separable convolution has KKM + MN parameters.

Let me try to calculate KKMN - (KKM + MN).

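nn.Conv2d(1, 4, kernel_size=2, padding=1)   (M = 1, N = 4)

2x2x1x4 - 2x2x1 - 1x4 = 8

nn.Conv2d(4, 8, kernel_size=2, padding=1)   (M = 4, N = 8)

2x2x4x8 - 2x2x4 - 4x8 = 80

nn.Conv2d(8, 16, kernel_size=2, padding=1)   (M = 8, N = 16)

2x2x8x16 - 2x2x8 - 8x16 = 352

Finally… (8 + 80 + 352) = 440 weights saved. Each depthwise layer also adds M bias terms (1 + 4 + 8 = 13), so the net saving is 440 - 13 = 427, which matches 11831 - 11404. Either way, the value is really small…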
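
The per-layer arithmetic can be checked with a few lines of plain Python. This is only a sketch, not code from the thread: the helper names `standard_conv_params` and `separable_conv_params` are made up for illustration, and the (M, N) pairs are read off the three convolution blocks in the question. It assumes the depthwise stage has one K x K filter per input channel (K*K*M weights) and that every nn.Conv2d keeps its default bias.

```python
def standard_conv_params(k, m, n, bias=True):
    """Weights (+ biases) of a k x k convolution from m to n channels."""
    return k * k * m * n + (n if bias else 0)


def separable_conv_params(k, m, n, bias=True):
    """Depthwise k x k conv on m channels followed by a 1x1 conv to n channels."""
    depthwise = k * k * m + (m if bias else 0)
    pointwise = m * n + (n if bias else 0)
    return depthwise + pointwise


# (in_channels, out_channels) of the three blocks in the question; kernel_size is 2.
layers = [(1, 4), (4, 8), (8, 16)]

saved_weights = sum(
    standard_conv_params(2, m, n, bias=False) - separable_conv_params(2, m, n, bias=False)
    for m, n in layers
)
saved_total = sum(
    standard_conv_params(2, m, n) - separable_conv_params(2, m, n)
    for m, n in layers
)

print(saved_weights)  # 440 weights saved when biases are ignored
print(saved_total)    # 427 once the extra depthwise biases are counted
```

The second number equals the gap between the two totals reported in the question (11831 - 11404).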

Thanks a bunch for this explanation!! :smile:

A couple of things I would ask…

Correct me if I’m wrong here. Putting it in mathematical terms, if I want this difference to be large,

KKMN - (KKM + MN) = KKMN - M(KK + N) = M(KKN - KK - N)

I should have a large M, and should try other values for K and N. This is the correct approach, right?

Second, although doing this might give me a large difference, it would probably take me away from my main objective of having as few parameters as possible to begin with. If that is correct, can you shed some light on how I can tweak the PyTorch model architecture to get fewer parameters (I had a model in Keras with the same structure and about 3000 parameters)? Or is a low parameter count even a good metric for how few resources a model uses during inference?

Some networks, such as MobileNet and ResNet, have layers with N = 512 or N = 256.
Assuming K = 2 and N = M/2, we have
KKMN - (KKM + MN) = 4MM/2 - (4M + MM/2) = 1.5MM - 4M.
What a big number it is!
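
To put numbers on this, here is a small sketch in plain Python. The function name `conv_param_saving` is hypothetical and biases are ignored, as in the formulas above. It evaluates KKMN - (KKM + MN) for the widest block in the question and for a much wider layer with M = 512 and N = 256.

```python
def conv_param_saving(k, m, n):
    """Return (standard, separable, saved) weight counts, ignoring biases."""
    standard = k * k * m * n
    separable = k * k * m + m * n
    return standard, separable, standard - separable


# Widest block from the question: 8 -> 16 channels, 2x2 kernel.
small = conv_param_saving(2, 8, 16)
# A wider layer with N = M / 2, as in the estimate above.
wide = conv_param_saving(2, 512, 256)

print(small)  # (512, 160, 352)
print(wide)   # (524288, 133120, 391168)

# The closed form 1.5*M*M - 4*M gives the same saving for K = 2, N = M / 2.
assert wide[2] == int(1.5 * 512 * 512 - 4 * 512)
```

Dividing the two counts gives separable / standard = 1/N + 1/(KK), so with a 2x2 kernel the separable layer can never be smaller than a quarter of the standard one, however many channels there are. A 3x3 kernel raises that ceiling to a 9x reduction.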

Right, that makes a lot of sense…
Thanks again!

Also: the majority of your parameters are in this one layer:

self.dense1 = nn.Linear(304, 36)

(304 in_features) * (36 out_features) = 10944 weights (plus 36 biases)
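
To see how lopsided the distribution is, the counts for every layer of the depthwise separable model can be tallied by hand. The sketch below is plain Python and makes two assumptions: every convolution and linear layer keeps its default bias, and each BatchNorm2d contributes one weight and one bias per channel. The dictionary name `layer_params` is made up for illustration.

```python
# Parameter counts per layer of the depthwise separable model in the question,
# computed by hand as weights + biases.
layer_params = {
    "depthwise1": 2 * 2 * 1 + 1,
    "pointwise1": 1 * 4 + 4,
    "bn1": 2 * 4,
    "depthwise2": 2 * 2 * 4 + 4,
    "pointwise2": 4 * 8 + 8,
    "bn2": 2 * 8,
    "depthwise3": 2 * 2 * 8 + 8,
    "pointwise3": 8 * 16 + 16,
    "bn3": 2 * 16,
    "dense1": 304 * 36 + 36,
    "dense2": 36 * 3 + 3,
}

total = sum(layer_params.values())
print(total)  # 11404, the total reported in the question

# Largest layers first, with their share of the total.
for name, count in sorted(layer_params.items(), key=lambda kv: -kv[1]):
    print(f"{name:<11} {count:>6} {100 * count / total:5.1f}%")
```

dense1 alone accounts for roughly 96% of the parameters, which is why changing the convolutions barely moves the total.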

So, you might consider average pooling down to 1x1xChannels right before dense1. And make dense1 a convolution with 304 in_channels, 36 out_channels, kernel_size of 1, and groups set to a value that divides both channel counts (nn.Conv2d requires this). That ought to bring your parameter count way down.

WHOA!! That really did work…
I have my parameters down from ~11k to ~2k and I can manually control them as well. That was a big help @solvingPuzzles :slightly_smiling_face:

However, the convolution layer had to be given the number of input channels after AvgPool, which turned out to be 16, not 16*19=304 (or did I perhaps misunderstand something here?). It is working now. Just for reference, here are the changes I made:

        ...

        self.bn3 = nn.BatchNorm2d(num_features=16)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.AvgPool2d(2)
        
        # x is the number of output channels; it must be a multiple of groups (16)
        self.conv = nn.Conv2d(16, x, kernel_size=1, groups=16)
        
        self.relu4 = nn.ReLU()
        self.dense2 = nn.Linear(x*19,3)

        ...

I’ll just have to evaluate and compare their performance soon enough…
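
As a rough check on how much the new head saves, the counts can be worked out in plain Python. This is a sketch under assumptions: `grouped_conv_params` and `linear_params` are hypothetical helpers, default biases are kept, and the values tried for x are arbitrary examples because the post leaves x open.

```python
def grouped_conv_params(in_ch, out_ch, k, groups, bias=True):
    """Parameter count of a k x k convolution with the given number of groups."""
    if in_ch % groups or out_ch % groups:
        raise ValueError("groups must divide both in_ch and out_ch")
    return out_ch * (in_ch // groups) * k * k + (out_ch if bias else 0)


def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


# Original head: Linear(304, 36) -> ReLU -> Linear(36, 3)
old_head = linear_params(304, 36) + linear_params(36, 3)
print(old_head)  # 11091

# Modified head: Conv2d(16, x, kernel_size=1, groups=16) -> ReLU -> Linear(x * 19, 3)
new_heads = {}
for x in (16, 32, 48):  # example values only; x must be a multiple of 16
    new_heads[x] = grouped_conv_params(16, x, 1, 16) + linear_params(x * 19, 3)
    print(x, new_heads[x])
```

With groups=16 the 1x1 convolution itself is nearly free (2 parameters per output channel), so almost all of the remaining head parameters sit in the final linear layer, which grows by 57 per unit of x.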
