Torchsummary forward/backward pass size (MB) extremely high (my own ResNet implementation)

I’m fairly new to this and am trying to implement my own version of ResNet on the CIFAR-10 dataset.

My issue is that when I compare my version (slightly changed from the original implementation) with the default PyTorch version, the torchsummary output is quite different, mainly in the forward/backward pass size.

I know I’m doing something wrong here… What increases the MB usage of the forward/backward pass? The number of parameters is about the same, but the PyTorch forward/backward pass size is only around 6% of mine (1.29 MB vs. 20.19 MB).

What effect does a large forward/backward pass size have?

Also:

If I increase the input tensor to (1, 3, 128, 128), the difference between the PyTorch version and mine becomes WILDLY larger.

For CIFAR-10:

My torchsummary output:

from torchsummary import summary

summary(net.cuda(), (3, 32, 32))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 32, 32]             192
       BatchNorm2d-2           [-1, 64, 32, 32]             128
            Conv2d-3           [-1, 64, 32, 32]           1,792
       BatchNorm2d-4           [-1, 64, 32, 32]             128
              ReLU-5           [-1, 64, 32, 32]               0
              ReLU-6           [-1, 64, 32, 32]               0
            Conv2d-7           [-1, 64, 32, 32]          36,928
       BatchNorm2d-8           [-1, 64, 32, 32]             128
         Dropout2d-9           [-1, 64, 32, 32]               0
             ReLU-10           [-1, 64, 32, 32]               0
            Block-11           [-1, 64, 32, 32]               0
           Conv2d-12           [-1, 64, 32, 32]           4,096
      BatchNorm2d-13           [-1, 64, 32, 32]             128
           Conv2d-14           [-1, 64, 32, 32]          36,928
      BatchNorm2d-15           [-1, 64, 32, 32]             128
             ReLU-16           [-1, 64, 32, 32]               0
             ReLU-17           [-1, 64, 32, 32]               0
           Conv2d-18           [-1, 64, 32, 32]          36,928
      BatchNorm2d-19           [-1, 64, 32, 32]             128
        Dropout2d-20           [-1, 64, 32, 32]               0
             ReLU-21           [-1, 64, 32, 32]               0
        MaxPool2d-22           [-1, 64, 16, 16]               0
            Block-23           [-1, 64, 16, 16]               0
           Conv2d-24          [-1, 128, 16, 16]           8,192
      BatchNorm2d-25          [-1, 128, 16, 16]             256
           Conv2d-26          [-1, 128, 16, 16]          73,856
      BatchNorm2d-27          [-1, 128, 16, 16]             256
             ReLU-28          [-1, 128, 16, 16]               0
             ReLU-29          [-1, 128, 16, 16]               0
           Conv2d-30          [-1, 128, 16, 16]         147,584
      BatchNorm2d-31          [-1, 128, 16, 16]             256
        Dropout2d-32          [-1, 128, 16, 16]               0
             ReLU-33          [-1, 128, 16, 16]               0
            Block-34          [-1, 128, 16, 16]               0
           Conv2d-35          [-1, 128, 16, 16]          16,384
      BatchNorm2d-36          [-1, 128, 16, 16]             256
           Conv2d-37          [-1, 128, 16, 16]         147,584
      BatchNorm2d-38          [-1, 128, 16, 16]             256
             ReLU-39          [-1, 128, 16, 16]               0
             ReLU-40          [-1, 128, 16, 16]               0
           Conv2d-41          [-1, 128, 16, 16]         147,584
      BatchNorm2d-42          [-1, 128, 16, 16]             256
        Dropout2d-43          [-1, 128, 16, 16]               0
             ReLU-44          [-1, 128, 16, 16]               0
        MaxPool2d-45            [-1, 128, 8, 8]               0
            Block-46            [-1, 128, 8, 8]               0
           Conv2d-47            [-1, 256, 8, 8]          32,768
      BatchNorm2d-48            [-1, 256, 8, 8]             512
           Conv2d-49            [-1, 256, 8, 8]         295,168
      BatchNorm2d-50            [-1, 256, 8, 8]             512
             ReLU-51            [-1, 256, 8, 8]               0
             ReLU-52            [-1, 256, 8, 8]               0
           Conv2d-53            [-1, 256, 8, 8]         590,080
      BatchNorm2d-54            [-1, 256, 8, 8]             512
        Dropout2d-55            [-1, 256, 8, 8]               0
             ReLU-56            [-1, 256, 8, 8]               0
            Block-57            [-1, 256, 8, 8]               0
           Conv2d-58            [-1, 256, 8, 8]          65,536
      BatchNorm2d-59            [-1, 256, 8, 8]             512
           Conv2d-60            [-1, 256, 8, 8]         590,080
      BatchNorm2d-61            [-1, 256, 8, 8]             512
             ReLU-62            [-1, 256, 8, 8]               0
             ReLU-63            [-1, 256, 8, 8]               0
           Conv2d-64            [-1, 256, 8, 8]         590,080
      BatchNorm2d-65            [-1, 256, 8, 8]             512
        Dropout2d-66            [-1, 256, 8, 8]               0
             ReLU-67            [-1, 256, 8, 8]               0
        MaxPool2d-68            [-1, 256, 4, 4]               0
            Block-69            [-1, 256, 4, 4]               0
           Conv2d-70            [-1, 512, 4, 4]         131,072
      BatchNorm2d-71            [-1, 512, 4, 4]           1,024
           Conv2d-72            [-1, 512, 4, 4]       1,180,160
      BatchNorm2d-73            [-1, 512, 4, 4]           1,024
             ReLU-74            [-1, 512, 4, 4]               0
             ReLU-75            [-1, 512, 4, 4]               0
           Conv2d-76            [-1, 512, 4, 4]       2,359,808
      BatchNorm2d-77            [-1, 512, 4, 4]           1,024
        Dropout2d-78            [-1, 512, 4, 4]               0
             ReLU-79            [-1, 512, 4, 4]               0
            Block-80            [-1, 512, 4, 4]               0
           Conv2d-81            [-1, 512, 4, 4]         262,144
      BatchNorm2d-82            [-1, 512, 4, 4]           1,024
           Conv2d-83            [-1, 512, 4, 4]       2,359,808
      BatchNorm2d-84            [-1, 512, 4, 4]           1,024
             ReLU-85            [-1, 512, 4, 4]               0
             ReLU-86            [-1, 512, 4, 4]               0
           Conv2d-87            [-1, 512, 4, 4]       2,359,808
      BatchNorm2d-88            [-1, 512, 4, 4]           1,024
        Dropout2d-89            [-1, 512, 4, 4]               0
             ReLU-90            [-1, 512, 4, 4]               0
        MaxPool2d-91            [-1, 512, 2, 2]               0
            Block-92            [-1, 512, 2, 2]               0
AdaptiveAvgPool2d-93            [-1, 512, 2, 2]               0
          Flatten-94                 [-1, 2048]               0
           Linear-95                  [-1, 512]       1,049,088
           Linear-96                   [-1, 10]           5,130
================================================================
Total params: 12,540,298
Trainable params: 12,540,298
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 20.19
Params size (MB): 47.84
Estimated Total Size (MB): 68.04
----------------------------------------------------------------

The original PyTorch ResNet implementation (torchvision's ResNet-18):

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 16, 16]           9,408
       BatchNorm2d-2           [-1, 64, 16, 16]             128
              ReLU-3           [-1, 64, 16, 16]               0
         MaxPool2d-4             [-1, 64, 8, 8]               0
            Conv2d-5             [-1, 64, 8, 8]          36,864
       BatchNorm2d-6             [-1, 64, 8, 8]             128
              ReLU-7             [-1, 64, 8, 8]               0
            Conv2d-8             [-1, 64, 8, 8]          36,864
       BatchNorm2d-9             [-1, 64, 8, 8]             128
             ReLU-10             [-1, 64, 8, 8]               0
       BasicBlock-11             [-1, 64, 8, 8]               0
           Conv2d-12             [-1, 64, 8, 8]          36,864
      BatchNorm2d-13             [-1, 64, 8, 8]             128
             ReLU-14             [-1, 64, 8, 8]               0
           Conv2d-15             [-1, 64, 8, 8]          36,864
      BatchNorm2d-16             [-1, 64, 8, 8]             128
             ReLU-17             [-1, 64, 8, 8]               0
       BasicBlock-18             [-1, 64, 8, 8]               0
           Conv2d-19            [-1, 128, 4, 4]          73,728
      BatchNorm2d-20            [-1, 128, 4, 4]             256
             ReLU-21            [-1, 128, 4, 4]               0
           Conv2d-22            [-1, 128, 4, 4]         147,456
      BatchNorm2d-23            [-1, 128, 4, 4]             256
           Conv2d-24            [-1, 128, 4, 4]           8,192
      BatchNorm2d-25            [-1, 128, 4, 4]             256
             ReLU-26            [-1, 128, 4, 4]               0
       BasicBlock-27            [-1, 128, 4, 4]               0
           Conv2d-28            [-1, 128, 4, 4]         147,456
      BatchNorm2d-29            [-1, 128, 4, 4]             256
             ReLU-30            [-1, 128, 4, 4]               0
           Conv2d-31            [-1, 128, 4, 4]         147,456
      BatchNorm2d-32            [-1, 128, 4, 4]             256
             ReLU-33            [-1, 128, 4, 4]               0
       BasicBlock-34            [-1, 128, 4, 4]               0
           Conv2d-35            [-1, 256, 2, 2]         294,912
      BatchNorm2d-36            [-1, 256, 2, 2]             512
             ReLU-37            [-1, 256, 2, 2]               0
           Conv2d-38            [-1, 256, 2, 2]         589,824
      BatchNorm2d-39            [-1, 256, 2, 2]             512
           Conv2d-40            [-1, 256, 2, 2]          32,768
      BatchNorm2d-41            [-1, 256, 2, 2]             512
             ReLU-42            [-1, 256, 2, 2]               0
       BasicBlock-43            [-1, 256, 2, 2]               0
           Conv2d-44            [-1, 256, 2, 2]         589,824
      BatchNorm2d-45            [-1, 256, 2, 2]             512
             ReLU-46            [-1, 256, 2, 2]               0
           Conv2d-47            [-1, 256, 2, 2]         589,824
      BatchNorm2d-48            [-1, 256, 2, 2]             512
             ReLU-49            [-1, 256, 2, 2]               0
       BasicBlock-50            [-1, 256, 2, 2]               0
           Conv2d-51            [-1, 512, 1, 1]       1,179,648
      BatchNorm2d-52            [-1, 512, 1, 1]           1,024
             ReLU-53            [-1, 512, 1, 1]               0
           Conv2d-54            [-1, 512, 1, 1]       2,359,296
      BatchNorm2d-55            [-1, 512, 1, 1]           1,024
           Conv2d-56            [-1, 512, 1, 1]         131,072
      BatchNorm2d-57            [-1, 512, 1, 1]           1,024
             ReLU-58            [-1, 512, 1, 1]               0
       BasicBlock-59            [-1, 512, 1, 1]               0
           Conv2d-60            [-1, 512, 1, 1]       2,359,296
      BatchNorm2d-61            [-1, 512, 1, 1]           1,024
             ReLU-62            [-1, 512, 1, 1]               0
           Conv2d-63            [-1, 512, 1, 1]       2,359,296
      BatchNorm2d-64            [-1, 512, 1, 1]           1,024
             ReLU-65            [-1, 512, 1, 1]               0
       BasicBlock-66            [-1, 512, 1, 1]               0
AdaptiveAvgPool2d-67            [-1, 512, 1, 1]               0
           Linear-68                 [-1, 1000]         513,000
================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 1.29
Params size (MB): 44.59
Estimated Total Size (MB): 45.90
----------------------------------------------------------------

My code if it helps:

import torch.nn as nn


def make_activation(act):
    # Build a fresh activation module of the requested type.
    if act == "relu":
        return nn.ReLU()
    if act == "prelu":
        return nn.PReLU()
    if act == "leak":
        return nn.LeakyReLU()
    raise ValueError(f"unknown activation: {act!r}")


class Block(nn.Module):

    def __init__(self, in_chan, out_chan, pool=True, act="relu", dropout=True):
        super().__init__()

        self.in_chan = in_chan
        self.out_chan = out_chan

        # Main path: conv -> BN -> act -> conv -> BN (-> Dropout2d).
        # The first activation lives only inside conv_1: if the same module is also stored
        # as an attribute of Block, Module.apply visits it from both parents, so torchsummary
        # registers two hooks on it and lists every call twice (e.g. ReLU-5 and ReLU-6).
        self.conv_1 = nn.Sequential(nn.Conv2d(in_chan, out_chan, 3, padding=1),
                                    nn.BatchNorm2d(out_chan),
                                    make_activation(act),
                                    nn.Conv2d(out_chan, out_chan, 3, padding=1),
                                    nn.BatchNorm2d(out_chan))
        if dropout:
            self.conv_1.add_module("Drop", nn.Dropout2d(0.2))

        # Activation applied after the residual addition.
        self.activation2 = make_activation(act)

        # Optional 2x2 max pool at the end of the block. This is the only place the
        # network downsamples, so the first blocks run at the full input resolution.
        self.pooling = nn.MaxPool2d(kernel_size=2, stride=2) if pool else None

        # Shortcut: 1x1 conv + BN projection, applied even when in_chan == out_chan.
        self.scaler = nn.Sequential(nn.Conv2d(in_chan, out_chan, 1, stride=1, bias=False),
                                    nn.BatchNorm2d(out_chan))

    def forward(self, x):
        scaled = self.scaler(x)
        block_out = self.conv_1(x)
        out = self.activation2(scaled + block_out)
        if self.pooling is not None:
            out = self.pooling(out)
        return out


class Net(nn.Module):

    def __init__(self, n_classes=10, in_chans=3,
                 model_shape=((64, 3), (128, 3), (256, 3)),
                 act="relu", adapt_pool_size=2, dropout=True):
        super().__init__()

        seq_final = nn.Sequential()
        chans = in_chans  # channels entering the current stage

        # Each (channels, n_blocks) entry becomes one stage; only its last block pools.
        for i, (out_chans, n_blocks) in enumerate(model_shape):
            seq = nn.Sequential()

            if n_blocks > 1:
                seq.add_module(f"Init Block {i}", Block(chans, out_chans, False, act=act, dropout=dropout))
                for x in range(n_blocks - 2):
                    seq.add_module(f"Inner Block {i}:{x}", Block(out_chans, out_chans, False, act=act, dropout=dropout))
                seq.add_module(f"Exit Block {i}", Block(out_chans, out_chans, True, act=act, dropout=dropout))
            else:
                seq.add_module(f"Lonely Block {i}", Block(chans, out_chans, True, act=act, dropout=dropout))

            chans = out_chans
            seq_final.add_module(f"Main Block {i}", seq)

        self.conv_block = seq_final

        # Pool to adapt_pool_size x adapt_pool_size so it matches inshape below.
        self.avgpool = nn.AdaptiveAvgPool2d((adapt_pool_size, adapt_pool_size))

        inshape = model_shape[-1][0] * adapt_pool_size * adapt_pool_size

        # Classifier head: dropout on the flattened features when dropout=True
        # (nn.Dropout, since the input is 2-D after Flatten).
        # Note: there is no activation between the two Linear layers, so together
        # they still compute a single linear map.
        layers = [nn.Flatten()]
        if dropout:
            layers.append(nn.Dropout(0.4))
        layers += [nn.Linear(inshape, inshape // 4),
                   nn.Linear(inshape // 4, n_classes)]
        self.linear = nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv_block(x)
        x = self.avgpool(x)
        x = self.linear(x)
        return x


net = Net(model_shape=[(64, 2), (128, 2), (256, 2), (512, 2)], act="relu", n_classes=10)


Hey, I was running into a similar memory usage issue with my custom model. Were you able to figure out the cause?

Maybe it’s a bit late to answer.
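If you compare the spatial resolution of corresponding layers in the two networks, you will notice that your custom network works on higher-resolution feature maps, whereas the official ResNet downsamples early and rapidly to reduce computation in later layers: its 7×7 stride-2 conv and stride-2 max pool shrink a 32×32 input to 8×8 before the first residual block (Conv2d-1 to MaxPool2d-4 in the summary above), while your first stage runs entirely at 32×32.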
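To put numbers on the resolution argument: torchsummary's "Forward/backward pass size (MB)" appears to be computed as 2 × (sum of the output elements of every hooked layer) × 4 bytes, i.e. each float32 output counted once for the forward pass and once for its gradient, with the batch dimension reported as -1 (so the figure is per sample). The sketch below assumes that formula and rebuilds the per-sample output shapes from the two summary tables in this thread; it reproduces the reported 20.19 MB and 1.29 MB. The helper names are hypothetical, and the shape builders assume an input size divisible by 32.

```python
from math import prod


def fwd_bwd_mb(output_shapes):
    """torchsummary-style estimate: float32 outputs (4 bytes), x2 for their gradients."""
    total_elements = sum(prod(shape) for shape in output_shapes)
    return 2 * total_elements * 4 / 1024 ** 2


def custom_net_shapes(size=32):
    """Per-sample output shapes listed in the custom network's summary (4 stages x 2 blocks)."""
    shapes, h = [], size
    for c in (64, 128, 256, 512):
        shapes += [(c, h, h)] * 21   # full resolution: convs, BNs, ReLUs, Dropout2d, first Block
        h //= 2                       # MaxPool2d(2, 2) at the end of the stage
        shapes += [(c, h, h)] * 2     # MaxPool2d + Block output
    shapes += [(512, 2, 2), (2048,), (512,), (10,)]  # AdaptiveAvgPool2d, Flatten, Linear, Linear
    return shapes


def resnet18_shapes(size=32):
    """Per-sample output shapes listed in the torchvision ResNet-18 summary."""
    h = size // 2                     # 7x7 stride-2 stem conv
    shapes = [(64, h, h)] * 3         # Conv2d, BatchNorm2d, ReLU
    h //= 2                           # 3x3 stride-2 MaxPool2d
    shapes += [(64, h, h)] * 15       # MaxPool2d + 14 entries of layer1
    for c in (128, 256, 512):
        h //= 2                       # each later stage starts with a stride-2 conv
        shapes += [(c, h, h)] * 16
    shapes += [(512, 1, 1), (1000,)]  # AdaptiveAvgPool2d, Linear
    return shapes


print(round(fwd_bwd_mb(custom_net_shapes()), 2))     # 20.19
print(round(fwd_bwd_mb(resnet18_shapes()), 2))       # 1.29
print(round(fwd_bwd_mb(custom_net_shapes(128)), 2))  # 322.54
print(round(fwd_bwd_mb(resnet18_shapes(128)), 2))    # 20.51
```

In the custom network the first stage alone (21 entries of 64×32×32) is over half of the total, while ResNet-18's residual blocks never see more than 8×8 for a 32×32 input. Since almost every entry scales with H×W, a 128×128 input multiplies both estimates by roughly 16, so the absolute gap grows from about 19 MB to about 300 MB. During training these outputs are roughly what has to be kept around for backprop, so the real cost also grows about linearly with batch size.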

Hi,

From what I have gathered from my own experience, the forward/backward pass size is mainly affected by the kernel sizes of the conv layers in the network together with the initial input size. Following PyTorch’s BCHW format for convnets, if H×W is much larger than the kernel size, the feature maps stay large and the forward/backward pass size increases; if H×W is smaller than the kernel size, the forward/backward pass size is smaller.
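The interaction between kernel size and input size can be made concrete with the output-size formula from the PyTorch `nn.Conv2d` documentation (`nn.MaxPool2d` with the default `ceil_mode=False` uses the same form). The sketch below uses a hypothetical helper, `conv_out`, on layer settings taken from this thread; a layer's output then holds C × H_out × W_out elements per sample, which is what torchsummary adds up.

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output height/width of nn.Conv2d (or nn.MaxPool2d) along one spatial dimension."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def activation_elements(channels, size):
    """Elements in one C x H x W feature map (per sample)."""
    return channels * size * size


# 3x3 conv, padding=1, stride=1 (as in the custom Block): 32x32 stays 32x32
same = conv_out(32, kernel=3, padding=1)
# ResNet stem: 7x7 conv (stride 2, padding 3), then 3x3 max pool (stride 2, padding 1): 32 -> 16 -> 8
stem = conv_out(conv_out(32, kernel=7, stride=2, padding=3), kernel=3, stride=2, padding=1)
# 3x3 conv without padding: 32 -> 30
valid = conv_out(32, kernel=3)

print(same, stem, valid)                                             # 32 8 30
print(activation_elements(64, same), activation_elements(64, stem))  # 65536 4096
```

Without padding a larger kernel only trims a few pixels from H and W; the 16× difference between 65,536 and 4,096 elements comes from the two stride-2 layers.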