I’m fairly new to this and am trying to implement my own version of ResNet on the CIFAR-10 dataset.
My issue is that when I compare my version (slightly changed from the original implementation) against the default PyTorch version, the output of torchsummary is quite different, mainly the forward/backward pass size.
I know I’m doing something wrong here… what increases the MB usage of the forward/backward pass? The parameter count is about the same, but the PyTorch forward/backward pass is around 6% of the size of mine.
What effect does having a large forward/backward pass size have?
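As far as I understand, torchsummary gets that number by adding up the elements of every layer’s output tensor, assuming 4 bytes each (float32), and doubling it to account for the gradients kept for the backward pass. Rough sketch of my understanding, not torchsummary’s actual code; activations_mb is just a helper I made up:

from math import prod

# Sketch of "Forward/backward pass size (MB)": sum the elements of every layer's
# output, 4 bytes per element (float32), times 2 for the stored gradients.
def activations_mb(output_shapes):
    total_elems = sum(prod(shape) for shape in output_shapes)
    return total_elems * 4 * 2 / (1024 ** 2)

# one of my 64x32x32 feature maps vs. one of PyTorch's 64x16x16 ones
print(activations_mb([(64, 32, 32)]))   # 0.5 MB
print(activations_mb([(64, 16, 16)]))   # 0.125 MB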
Also:
If I increase the input tensor to (1, 3, 128, 128), the gap between the PyTorch version and mine becomes WILDLY larger.
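That is, the same summary call on the same model, just with a bigger input:

from torchsummary import summary
summary(net.cuda(), (3, 128, 128))  # same model, larger input -> much larger forward/backward pass size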
For CIFAR-10:
My torchsummary output:
from torchsummary import summary
summary(net.cuda(), (3, 32, 32))
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 32, 32] 192
BatchNorm2d-2 [-1, 64, 32, 32] 128
Conv2d-3 [-1, 64, 32, 32] 1,792
BatchNorm2d-4 [-1, 64, 32, 32] 128
ReLU-5 [-1, 64, 32, 32] 0
ReLU-6 [-1, 64, 32, 32] 0
Conv2d-7 [-1, 64, 32, 32] 36,928
BatchNorm2d-8 [-1, 64, 32, 32] 128
Dropout2d-9 [-1, 64, 32, 32] 0
ReLU-10 [-1, 64, 32, 32] 0
Block-11 [-1, 64, 32, 32] 0
Conv2d-12 [-1, 64, 32, 32] 4,096
BatchNorm2d-13 [-1, 64, 32, 32] 128
Conv2d-14 [-1, 64, 32, 32] 36,928
BatchNorm2d-15 [-1, 64, 32, 32] 128
ReLU-16 [-1, 64, 32, 32] 0
ReLU-17 [-1, 64, 32, 32] 0
Conv2d-18 [-1, 64, 32, 32] 36,928
BatchNorm2d-19 [-1, 64, 32, 32] 128
Dropout2d-20 [-1, 64, 32, 32] 0
ReLU-21 [-1, 64, 32, 32] 0
MaxPool2d-22 [-1, 64, 16, 16] 0
Block-23 [-1, 64, 16, 16] 0
Conv2d-24 [-1, 128, 16, 16] 8,192
BatchNorm2d-25 [-1, 128, 16, 16] 256
Conv2d-26 [-1, 128, 16, 16] 73,856
BatchNorm2d-27 [-1, 128, 16, 16] 256
ReLU-28 [-1, 128, 16, 16] 0
ReLU-29 [-1, 128, 16, 16] 0
Conv2d-30 [-1, 128, 16, 16] 147,584
BatchNorm2d-31 [-1, 128, 16, 16] 256
Dropout2d-32 [-1, 128, 16, 16] 0
ReLU-33 [-1, 128, 16, 16] 0
Block-34 [-1, 128, 16, 16] 0
Conv2d-35 [-1, 128, 16, 16] 16,384
BatchNorm2d-36 [-1, 128, 16, 16] 256
Conv2d-37 [-1, 128, 16, 16] 147,584
BatchNorm2d-38 [-1, 128, 16, 16] 256
ReLU-39 [-1, 128, 16, 16] 0
ReLU-40 [-1, 128, 16, 16] 0
Conv2d-41 [-1, 128, 16, 16] 147,584
BatchNorm2d-42 [-1, 128, 16, 16] 256
Dropout2d-43 [-1, 128, 16, 16] 0
ReLU-44 [-1, 128, 16, 16] 0
MaxPool2d-45 [-1, 128, 8, 8] 0
Block-46 [-1, 128, 8, 8] 0
Conv2d-47 [-1, 256, 8, 8] 32,768
BatchNorm2d-48 [-1, 256, 8, 8] 512
Conv2d-49 [-1, 256, 8, 8] 295,168
BatchNorm2d-50 [-1, 256, 8, 8] 512
ReLU-51 [-1, 256, 8, 8] 0
ReLU-52 [-1, 256, 8, 8] 0
Conv2d-53 [-1, 256, 8, 8] 590,080
BatchNorm2d-54 [-1, 256, 8, 8] 512
Dropout2d-55 [-1, 256, 8, 8] 0
ReLU-56 [-1, 256, 8, 8] 0
Block-57 [-1, 256, 8, 8] 0
Conv2d-58 [-1, 256, 8, 8] 65,536
BatchNorm2d-59 [-1, 256, 8, 8] 512
Conv2d-60 [-1, 256, 8, 8] 590,080
BatchNorm2d-61 [-1, 256, 8, 8] 512
ReLU-62 [-1, 256, 8, 8] 0
ReLU-63 [-1, 256, 8, 8] 0
Conv2d-64 [-1, 256, 8, 8] 590,080
BatchNorm2d-65 [-1, 256, 8, 8] 512
Dropout2d-66 [-1, 256, 8, 8] 0
ReLU-67 [-1, 256, 8, 8] 0
MaxPool2d-68 [-1, 256, 4, 4] 0
Block-69 [-1, 256, 4, 4] 0
Conv2d-70 [-1, 512, 4, 4] 131,072
BatchNorm2d-71 [-1, 512, 4, 4] 1,024
Conv2d-72 [-1, 512, 4, 4] 1,180,160
BatchNorm2d-73 [-1, 512, 4, 4] 1,024
ReLU-74 [-1, 512, 4, 4] 0
ReLU-75 [-1, 512, 4, 4] 0
Conv2d-76 [-1, 512, 4, 4] 2,359,808
BatchNorm2d-77 [-1, 512, 4, 4] 1,024
Dropout2d-78 [-1, 512, 4, 4] 0
ReLU-79 [-1, 512, 4, 4] 0
Block-80 [-1, 512, 4, 4] 0
Conv2d-81 [-1, 512, 4, 4] 262,144
BatchNorm2d-82 [-1, 512, 4, 4] 1,024
Conv2d-83 [-1, 512, 4, 4] 2,359,808
BatchNorm2d-84 [-1, 512, 4, 4] 1,024
ReLU-85 [-1, 512, 4, 4] 0
ReLU-86 [-1, 512, 4, 4] 0
Conv2d-87 [-1, 512, 4, 4] 2,359,808
BatchNorm2d-88 [-1, 512, 4, 4] 1,024
Dropout2d-89 [-1, 512, 4, 4] 0
ReLU-90 [-1, 512, 4, 4] 0
MaxPool2d-91 [-1, 512, 2, 2] 0
Block-92 [-1, 512, 2, 2] 0
AdaptiveAvgPool2d-93 [-1, 512, 2, 2] 0
Flatten-94 [-1, 2048] 0
Linear-95 [-1, 512] 1,049,088
Linear-96 [-1, 10] 5,130
================================================================
Total params: 12,540,298
Trainable params: 12,540,298
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 20.19
Params size (MB): 47.84
Estimated Total Size (MB): 68.04
----------------------------------------------------------------
The original PyTorch ResNet implementation:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 16, 16] 9,408
BatchNorm2d-2 [-1, 64, 16, 16] 128
ReLU-3 [-1, 64, 16, 16] 0
MaxPool2d-4 [-1, 64, 8, 8] 0
Conv2d-5 [-1, 64, 8, 8] 36,864
BatchNorm2d-6 [-1, 64, 8, 8] 128
ReLU-7 [-1, 64, 8, 8] 0
Conv2d-8 [-1, 64, 8, 8] 36,864
BatchNorm2d-9 [-1, 64, 8, 8] 128
ReLU-10 [-1, 64, 8, 8] 0
BasicBlock-11 [-1, 64, 8, 8] 0
Conv2d-12 [-1, 64, 8, 8] 36,864
BatchNorm2d-13 [-1, 64, 8, 8] 128
ReLU-14 [-1, 64, 8, 8] 0
Conv2d-15 [-1, 64, 8, 8] 36,864
BatchNorm2d-16 [-1, 64, 8, 8] 128
ReLU-17 [-1, 64, 8, 8] 0
BasicBlock-18 [-1, 64, 8, 8] 0
Conv2d-19 [-1, 128, 4, 4] 73,728
BatchNorm2d-20 [-1, 128, 4, 4] 256
ReLU-21 [-1, 128, 4, 4] 0
Conv2d-22 [-1, 128, 4, 4] 147,456
BatchNorm2d-23 [-1, 128, 4, 4] 256
Conv2d-24 [-1, 128, 4, 4] 8,192
BatchNorm2d-25 [-1, 128, 4, 4] 256
ReLU-26 [-1, 128, 4, 4] 0
BasicBlock-27 [-1, 128, 4, 4] 0
Conv2d-28 [-1, 128, 4, 4] 147,456
BatchNorm2d-29 [-1, 128, 4, 4] 256
ReLU-30 [-1, 128, 4, 4] 0
Conv2d-31 [-1, 128, 4, 4] 147,456
BatchNorm2d-32 [-1, 128, 4, 4] 256
ReLU-33 [-1, 128, 4, 4] 0
BasicBlock-34 [-1, 128, 4, 4] 0
Conv2d-35 [-1, 256, 2, 2] 294,912
BatchNorm2d-36 [-1, 256, 2, 2] 512
ReLU-37 [-1, 256, 2, 2] 0
Conv2d-38 [-1, 256, 2, 2] 589,824
BatchNorm2d-39 [-1, 256, 2, 2] 512
Conv2d-40 [-1, 256, 2, 2] 32,768
BatchNorm2d-41 [-1, 256, 2, 2] 512
ReLU-42 [-1, 256, 2, 2] 0
BasicBlock-43 [-1, 256, 2, 2] 0
Conv2d-44 [-1, 256, 2, 2] 589,824
BatchNorm2d-45 [-1, 256, 2, 2] 512
ReLU-46 [-1, 256, 2, 2] 0
Conv2d-47 [-1, 256, 2, 2] 589,824
BatchNorm2d-48 [-1, 256, 2, 2] 512
ReLU-49 [-1, 256, 2, 2] 0
BasicBlock-50 [-1, 256, 2, 2] 0
Conv2d-51 [-1, 512, 1, 1] 1,179,648
BatchNorm2d-52 [-1, 512, 1, 1] 1,024
ReLU-53 [-1, 512, 1, 1] 0
Conv2d-54 [-1, 512, 1, 1] 2,359,296
BatchNorm2d-55 [-1, 512, 1, 1] 1,024
Conv2d-56 [-1, 512, 1, 1] 131,072
BatchNorm2d-57 [-1, 512, 1, 1] 1,024
ReLU-58 [-1, 512, 1, 1] 0
BasicBlock-59 [-1, 512, 1, 1] 0
Conv2d-60 [-1, 512, 1, 1] 2,359,296
BatchNorm2d-61 [-1, 512, 1, 1] 1,024
ReLU-62 [-1, 512, 1, 1] 0
Conv2d-63 [-1, 512, 1, 1] 2,359,296
BatchNorm2d-64 [-1, 512, 1, 1] 1,024
ReLU-65 [-1, 512, 1, 1] 0
BasicBlock-66 [-1, 512, 1, 1] 0
AdaptiveAvgPool2d-67 [-1, 512, 1, 1] 0
Linear-68 [-1, 1000] 513,000
================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 1.29
Params size (MB): 44.59
Estimated Total Size (MB): 45.90
----------------------------------------------------------------
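For reference, this is how I’d reproduce the summary above. I’m assuming it’s torchvision’s resnet18, which has the same 11,689,512 parameters:

from torchvision.models import resnet18
from torchsummary import summary

# torchvision ResNet-18 (ImageNet-style stem: 7x7 stride-2 conv + max pool),
# summarised with the same CIFAR-10 input size
summary(resnet18().cuda(), (3, 32, 32))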
My code, in case it helps:
import torch.nn as nn
import torch.nn.functional as F


class Block(nn.Module):
    """Residual block: two 3x3 convs plus a 1x1 shortcut, with optional max pooling."""

    def __init__(self, in_chan, out_chan, pool=True, act="relu", dropout=True):
        super(Block, self).__init__()
        self.in_chan = in_chan
        self.out_chan = out_chan
        if act == "relu":
            self.activation1 = nn.ReLU()
            self.activation2 = nn.ReLU()
        elif act == "prelu":
            self.activation1 = nn.PReLU()
            self.activation2 = nn.PReLU()
        elif act == "leak":
            self.activation1 = nn.LeakyReLU()
            self.activation2 = nn.LeakyReLU()
        # main path: conv -> BN -> act -> conv -> BN (-> dropout)
        self.conv_1 = nn.Sequential(nn.Conv2d(in_chan, out_chan, 3, padding=1),
                                    nn.BatchNorm2d(out_chan),
                                    self.activation1,
                                    nn.Conv2d(out_chan, out_chan, 3, padding=1),
                                    nn.BatchNorm2d(out_chan))
        if dropout:
            self.conv_1.add_module("Drop", nn.Dropout2d(0.2))
        if pool:
            self.pooling = nn.MaxPool2d(kernel_size=2, stride=2)
        # shortcut path: 1x1 conv to match the channel count
        self.scaler = nn.Sequential(nn.Conv2d(in_chan, out_chan, 1, stride=1, bias=False),
                                    nn.BatchNorm2d(out_chan))

    def forward(self, x):
        scaled = self.scaler(x)
        block_out = self.conv_1(x)
        if hasattr(self, "pooling"):
            return self.pooling(self.activation2(scaled + block_out))
        return self.activation2(scaled + block_out)


class Net(nn.Module):
    def __init__(self, n_classes=10, in_chans=3, model_shape=[(64, 3), (128, 3), (256, 3)],
                 act="relu", adapt_pool_size=2, dropout=True):
        super(Net, self).__init__()
        seq_final = nn.Sequential()
        chans = in_chans
        # model_shape is a list of (channels, number_of_blocks) per stage;
        # only the last block of each stage pools, halving the spatial size
        for i, blocks in enumerate(model_shape):
            seq = nn.Sequential()
            if blocks[1] > 1:
                seq.add_module(f"Init Block {i}", Block(chans, blocks[0], False, act=act, dropout=dropout))
                if blocks[1] > 2:
                    for x in range(blocks[1] - 2):
                        seq.add_module(f"Inner Block {i}:{x}", Block(blocks[0], blocks[0], False, act=act, dropout=dropout))
                seq.add_module(f"Exit Block {i}", Block(blocks[0], blocks[0], True, act=act, dropout=dropout))
            else:
                seq.add_module(f"Lonely Block {i}", Block(chans, blocks[0], True, act=act, dropout=dropout))
            chans = blocks[0]
            seq_final.add_module(f"Main Block {i}", seq)
        self.conv_block = seq_final
        self.avgpool = nn.AdaptiveAvgPool2d((2, 2))
        inshape = model_shape[-1][0] * adapt_pool_size * adapt_pool_size
        if dropout:
            self.linear = nn.Sequential(nn.Flatten(),
                                        nn.Linear(inshape, int(inshape / 4)),
                                        nn.Linear(int(inshape / 4), n_classes))
        else:
            self.linear = nn.Sequential(nn.Flatten(),
                                        nn.Dropout2d(0.4),
                                        nn.Linear(inshape, int(inshape / 4)),
                                        nn.Linear(int(inshape / 4), n_classes))

    def forward(self, x):
        x = self.conv_block(x)
        x = self.avgpool(x)
        x = self.linear(x)
        return x


net = Net(model_shape=[(64, 2), (128, 2), (256, 2), (512, 2)], act="relu", n_classes=10)
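For what it’s worth, this is the kind of check I mean when I ask what effect a large forward/backward pass has. It’s a rough sketch; peak_training_memory_mb is just a hypothetical helper for measuring actual GPU memory during one training step:

import torch

def peak_training_memory_mb(model, input_shape, batch_size=128):
    """Rough measure of peak GPU memory for one forward + backward pass."""
    model = model.cuda()
    torch.cuda.reset_peak_memory_stats()
    x = torch.randn(batch_size, *input_shape, device="cuda")
    out = model(x)
    out.sum().backward()               # backward pass uses the saved activations
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / (1024 ** 2)

print(peak_training_memory_mb(net, (3, 32, 32)))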