Is there a PyTorch function similar to Keras's model.summary()?
repr(model)
gives something fairly close.
Or, easier to remember:
print(model)
gives the same result as repr(model)
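For instance, a minimal sketch of what that prints (the repr format comes from PyTorch itself):

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU(), nn.Flatten())
print(model)
# Sequential(
#   (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1))
#   (1): ReLU()
#   (2): Flatten(start_dim=1, end_dim=-1)
# )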
Also, if you just want the number of parameters as opposed to the whole model, you can do something like: sum([param.nelement() for param in model.parameters()])
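If you also want that split into trainable vs. frozen counts, as in Keras's summary footer, a small variation works (a sketch; model is any nn.Module):

total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total: {total:,} | Trainable: {trainable:,} | Frozen: {total - trainable:,}")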
Yes, you can get an exact Keras-style representation using this code.
Example for VGG16:
from torchvision import models
from torchsummary import summary
vgg = models.vgg16()
summary(vgg, (3, 224, 224))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 224, 224]            1792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]           36928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]           73856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]          147584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          295168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]          590080
             ReLU-14          [-1, 256, 56, 56]               0
           Conv2d-15          [-1, 256, 56, 56]          590080
             ReLU-16          [-1, 256, 56, 56]               0
        MaxPool2d-17          [-1, 256, 28, 28]               0
           Conv2d-18          [-1, 512, 28, 28]         1180160
             ReLU-19          [-1, 512, 28, 28]               0
           Conv2d-20          [-1, 512, 28, 28]         2359808
             ReLU-21          [-1, 512, 28, 28]               0
           Conv2d-22          [-1, 512, 28, 28]         2359808
             ReLU-23          [-1, 512, 28, 28]               0
        MaxPool2d-24          [-1, 512, 14, 14]               0
           Conv2d-25          [-1, 512, 14, 14]         2359808
             ReLU-26          [-1, 512, 14, 14]               0
           Conv2d-27          [-1, 512, 14, 14]         2359808
             ReLU-28          [-1, 512, 14, 14]               0
           Conv2d-29          [-1, 512, 14, 14]         2359808
             ReLU-30          [-1, 512, 14, 14]               0
        MaxPool2d-31            [-1, 512, 7, 7]               0
           Linear-32                 [-1, 4096]       102764544
             ReLU-33                 [-1, 4096]               0
          Dropout-34                 [-1, 4096]               0
           Linear-35                 [-1, 4096]        16781312
             ReLU-36                 [-1, 4096]               0
          Dropout-37                 [-1, 4096]               0
           Linear-38                 [-1, 1000]         4097000
================================================================
Total params: 138357544
Trainable params: 138357544
Non-trainable params: 0
----------------------------------------------------------------
Actually, there’s a difference between Keras's model.summary() and print(model) in PyTorch. print(model) in PyTorch only prints the layers defined in the __init__ function of the class, not the model architecture defined in the forward function. Keras's model.summary() actually prints the model architecture with input and output shapes, along with the trainable and non-trainable parameter counts.
I haven’t found anything like that in PyTorch. I end up writing a bunch of print statements in the forward function to determine the input and output shapes.
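A less invasive variant of that trick (a minimal sketch, not from the thread) is to register forward hooks instead of editing forward(), so the shape printing can be attached to and removed from any model. Note that hooks only fire for nn.Module calls, so functional ops like F.relu still will not show up (the same limitation discussed below):

import torch
import torch.nn as nn

def shape_hook(module, inputs, output):
    # inputs is a tuple of tensors; output may be a tensor or something else
    in_shapes = [tuple(t.shape) for t in inputs if isinstance(t, torch.Tensor)]
    out_shape = tuple(output.shape) if isinstance(output, torch.Tensor) else type(output).__name__
    print(f"{module.__class__.__name__}: {in_shapes} -> {out_shape}")

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Flatten(), nn.Linear(16 * 26 * 26, 10))
# attach to leaf modules only, so containers don't double-print
handles = [m.register_forward_hook(shape_hook)
           for m in model.modules() if len(list(m.children())) == 0]

with torch.no_grad():
    model(torch.randn(1, 3, 28, 28))

for h in handles:
    h.remove()  # detach the hooks when done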
Is the suggestion from @sksq96 not sufficient for the shape information?
If you check the repo, that summary only prints what’s in the __init__ function, not the actual forward function, where you might apply batch normalization, ReLU, max pooling, global max pooling, and the like.
The VGG model prints correctly because it’s implemented that way: its ReLU layers are defined as modules in __init__ instead of as functional F.relu calls. I faced a huge problem when implementing UNet, and that’s just one example.
Check this example from the repo:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # PyTorch v0.4.0
model = Net().to(device)
summary(model, (1, 28, 28))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 10, 24, 24]             260
            Conv2d-2             [-1, 20, 8, 8]           5,020
         Dropout2d-3             [-1, 20, 8, 8]               0
            Linear-4                   [-1, 50]          16,050
            Linear-5                   [-1, 10]             510
================================================================
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.06
Params size (MB): 0.08
Estimated Total Size (MB): 0.15
----------------------------------------------------------------
This summary absolutely doesn’t make any sense.
The summary would still give you all parameters, no?
If you are interested in printing the graph, have a look at e.g. this topic.
I just want an easy function call to print the model summary the way Keras does. PyTorch should have added that. I should not be doing all kinds of tricks just to see my model summary with the input and output shapes of every layer. I love working in PyTorch; that’s why I am looking for that type of function, which would make model development easy.
pytorch-summary
should yield the same output as Keras does.
What are you missing?
pytorch-summary can’t handle things like lists of tensors in forward(). PyTorch should definitely have a summary() function that can handle any operations done inside the model.
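For concreteness, here is a hypothetical model of the kind that trips up shape-based summary tools: its forward() consumes a Python list of tensors rather than a single tensor.

import torch
import torch.nn as nn

class MultiInputNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(32, 16)

    def forward(self, xs):  # xs: a list of [B, 32] tensors
        # project each tensor, then average across the list
        return torch.stack([self.proj(x) for x in xs]).mean(dim=0)

model = MultiInputNet()
out = model([torch.randn(4, 32), torch.randn(4, 32)])
print(out.shape)  # torch.Size([4, 16])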
You can try the pytorch-model-summary library. I rewrote it to support the scenario you described, adding other options as well.
This worked for me:
!pip install torchsummary
from torchsummary import summary
summary(your_model, input_size=(channels, H, W))
Really love your work! It works perfectly, thank you.
For a given input shape, you can use the torchinfo (formerly torchsummary) package:
Torchinfo provides information complementary to what is provided by print(your_model) in PyTorch, similar to TensorFlow’s model.summary() …
Example:
from torchinfo import summary
model = ConvNet()
batch_size = 16
summary(model, input_size=(batch_size, 1, 28, 28))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Conv2d (conv1): 1-1                    [5, 10, 24, 24]           260
├─Conv2d (conv2): 1-2                    [5, 20, 8, 8]             5,020
├─Dropout2d (conv2_drop): 1-3            [5, 20, 8, 8]             --
├─Linear (fc1): 1-4                      [5, 50]                   16,050
├─Linear (fc2): 1-5                      [5, 10]                   510
==========================================================================================
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
Total mult-adds (M): 7.69
==========================================================================================
Input size (MB): 0.05
Forward/backward pass size (MB): 0.91
Params size (MB): 0.09
Estimated Total Size (MB): 1.05
==========================================================================================
Note: Unlike Keras, PyTorch uses a dynamic computational graph that can adapt to any compatible input shape across multiple calls, e.g. any sufficiently large image size (for a fully convolutional network).
As such, it cannot present an inherent set of input/output shapes for each layer, as these are input-dependent, which is why you must specify the input dimensions in the above package.
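To see that input-dependence concretely, here is a minimal fully convolutional sketch (my own example, not from the package docs) that accepts several input sizes, so no fixed per-layer shape table exists until an input size is chosen:

import torch
import torch.nn as nn

fcn = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),   # padding preserves the spatial size
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),         # collapses any spatial size to 1x1
)

for size in (32, 224):
    out = fcn(torch.randn(1, 3, size, size))
    # intermediate conv shapes differ per input; the final shape does not
    print(size, tuple(out.shape))    # 32 (1, 8, 1, 1), then 224 (1, 8, 1, 1)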
The one I use is pytorchsummary. It also returns the number of trainable parameters as (weights + bias) tuples, which is quite useful in some cases.
from pytorchsummary import summary
from torchvision import models
model = models.efficientnet_b0(False)
summary(input_size=(3, 224, 224), model=model)
Output:
Layer Output Shape Kernel Shape #params #(weights + bias) requires_grad
------------------------------------------------------------------------------------------------------------------------------------------------------
Conv2d-1 [1, 32, 112, 112] [32, 3, 3, 3] 864 (864+0) True
BatchNorm2d-2 [1, 32, 112, 112] [32] 64 (32 + 32) True True
SiLU-3 [1, 32, 112, 112]
Conv2d-4 [1, 32, 112, 112] [32, 1, 3, 3] 288 (288+0) True
BatchNorm2d-5 [1, 32, 112, 112] [32] 64 (32 + 32) True True
SiLU-6 [1, 32, 112, 112]
AdaptiveAvgPool2d-7 [1, 32, 1, 1]
Conv2d-8 [1, 8, 1, 1] [8, 32, 1, 1] 264 (256 + 8) True True
SiLU-9 [1, 8, 1, 1]
Conv2d-10 [1, 32, 1, 1] [32, 8, 1, 1] 288 (256 + 32) True True
Sigmoid-11 [1, 32, 1, 1]
Conv2d-12 [1, 16, 112, 112] [16, 32, 1, 1] 512 (512+0) True
BatchNorm2d-13 [1, 16, 112, 112] [16] 32 (16 + 16) True True
Conv2d-14 [1, 96, 112, 112] [96, 16, 1, 1] 1536 (1536+0) True
BatchNorm2d-15 [1, 96, 112, 112] [96] 192 (96 + 96) True True
SiLU-16 [1, 96, 112, 112]
Conv2d-17 [1, 96, 56, 56] [96, 1, 3, 3] 864 (864+0) True
BatchNorm2d-18 [1, 96, 56, 56] [96] 192 (96 + 96) True True
SiLU-19 [1, 96, 56, 56]
AdaptiveAvgPool2d-20 [1, 96, 1, 1]
Conv2d-21 [1, 4, 1, 1] [4, 96, 1, 1] 388 (384 + 4) True True
SiLU-22 [1, 4, 1, 1]
Conv2d-23 [1, 96, 1, 1] [96, 4, 1, 1] 480 (384 + 96) True True
Sigmoid-24 [1, 96, 1, 1]
Conv2d-25 [1, 24, 56, 56] [24, 96, 1, 1] 2304 (2304+0) True
BatchNorm2d-26 [1, 24, 56, 56] [24] 48 (24 + 24) True True
Conv2d-27 [1, 144, 56, 56] [144, 24, 1, 1] 3456 (3456+0) True
BatchNorm2d-28 [1, 144, 56, 56] [144] 288 (144 + 144) True True
SiLU-29 [1, 144, 56, 56]
Conv2d-30 [1, 144, 56, 56] [144, 1, 3, 3] 1296 (1296+0) True
BatchNorm2d-31 [1, 144, 56, 56] [144] 288 (144 + 144) True True
SiLU-32 [1, 144, 56, 56]
AdaptiveAvgPool2d-33 [1, 144, 1, 1]
Conv2d-34 [1, 6, 1, 1] [6, 144, 1, 1] 870 (864 + 6) True True
SiLU-35 [1, 6, 1, 1]
Conv2d-36 [1, 144, 1, 1] [144, 6, 1, 1] 1008 (864 + 144) True True
Sigmoid-37 [1, 144, 1, 1]
Conv2d-38 [1, 24, 56, 56] [24, 144, 1, 1] 3456 (3456+0) True
BatchNorm2d-39 [1, 24, 56, 56] [24] 48 (24 + 24) True True
Conv2d-40 [1, 144, 56, 56] [144, 24, 1, 1] 3456 (3456+0) True
BatchNorm2d-41 [1, 144, 56, 56] [144] 288 (144 + 144) True True
SiLU-42 [1, 144, 56, 56]
Conv2d-43 [1, 144, 28, 28] [144, 1, 5, 5] 3600 (3600+0) True
BatchNorm2d-44 [1, 144, 28, 28] [144] 288 (144 + 144) True True
SiLU-45 [1, 144, 28, 28]
AdaptiveAvgPool2d-46 [1, 144, 1, 1]
Conv2d-47 [1, 6, 1, 1] [6, 144, 1, 1] 870 (864 + 6) True True
SiLU-48 [1, 6, 1, 1]
Conv2d-49 [1, 144, 1, 1] [144, 6, 1, 1] 1008 (864 + 144) True True
Sigmoid-50 [1, 144, 1, 1]
Conv2d-51 [1, 40, 28, 28] [40, 144, 1, 1] 5760 (5760+0) True
BatchNorm2d-52 [1, 40, 28, 28] [40] 80 (40 + 40) True True
Conv2d-53 [1, 240, 28, 28] [240, 40, 1, 1] 9600 (9600+0) True
BatchNorm2d-54 [1, 240, 28, 28] [240] 480 (240 + 240) True True
SiLU-55 [1, 240, 28, 28]
Conv2d-56 [1, 240, 28, 28] [240, 1, 5, 5] 6000 (6000+0) True
BatchNorm2d-57 [1, 240, 28, 28] [240] 480 (240 + 240) True True
SiLU-58 [1, 240, 28, 28]
AdaptiveAvgPool2d-59 [1, 240, 1, 1]
Conv2d-60 [1, 10, 1, 1] [10, 240, 1, 1] 2410 (2400 + 10) True True
SiLU-61 [1, 10, 1, 1]
Conv2d-62 [1, 240, 1, 1] [240, 10, 1, 1] 2640 (2400 + 240) True True
Sigmoid-63 [1, 240, 1, 1]
Conv2d-64 [1, 40, 28, 28] [40, 240, 1, 1] 9600 (9600+0) True
BatchNorm2d-65 [1, 40, 28, 28] [40] 80 (40 + 40) True True
Conv2d-66 [1, 240, 28, 28] [240, 40, 1, 1] 9600 (9600+0) True
BatchNorm2d-67 [1, 240, 28, 28] [240] 480 (240 + 240) True True
SiLU-68 [1, 240, 28, 28]
Conv2d-69 [1, 240, 14, 14] [240, 1, 3, 3] 2160 (2160+0) True
BatchNorm2d-70 [1, 240, 14, 14] [240] 480 (240 + 240) True True
SiLU-71 [1, 240, 14, 14]
AdaptiveAvgPool2d-72 [1, 240, 1, 1]
Conv2d-73 [1, 10, 1, 1] [10, 240, 1, 1] 2410 (2400 + 10) True True
SiLU-74 [1, 10, 1, 1]
Conv2d-75 [1, 240, 1, 1] [240, 10, 1, 1] 2640 (2400 + 240) True True
Sigmoid-76 [1, 240, 1, 1]
Conv2d-77 [1, 80, 14, 14] [80, 240, 1, 1] 19200 (19200+0) True
BatchNorm2d-78 [1, 80, 14, 14] [80] 160 (80 + 80) True True
Conv2d-79 [1, 480, 14, 14] [480, 80, 1, 1] 38400 (38400+0) True
BatchNorm2d-80 [1, 480, 14, 14] [480] 960 (480 + 480) True True
SiLU-81 [1, 480, 14, 14]
Conv2d-82 [1, 480, 14, 14] [480, 1, 3, 3] 4320 (4320+0) True
BatchNorm2d-83 [1, 480, 14, 14] [480] 960 (480 + 480) True True
SiLU-84 [1, 480, 14, 14]
AdaptiveAvgPool2d-85 [1, 480, 1, 1]
Conv2d-86 [1, 20, 1, 1] [20, 480, 1, 1] 9620 (9600 + 20) True True
SiLU-87 [1, 20, 1, 1]
Conv2d-88 [1, 480, 1, 1] [480, 20, 1, 1] 10080 (9600 + 480) True True
Sigmoid-89 [1, 480, 1, 1]
Conv2d-90 [1, 80, 14, 14] [80, 480, 1, 1] 38400 (38400+0) True
BatchNorm2d-91 [1, 80, 14, 14] [80] 160 (80 + 80) True True
Conv2d-92 [1, 480, 14, 14] [480, 80, 1, 1] 38400 (38400+0) True
BatchNorm2d-93 [1, 480, 14, 14] [480] 960 (480 + 480) True True
SiLU-94 [1, 480, 14, 14]
Conv2d-95 [1, 480, 14, 14] [480, 1, 3, 3] 4320 (4320+0) True
BatchNorm2d-96 [1, 480, 14, 14] [480] 960 (480 + 480) True True
SiLU-97 [1, 480, 14, 14]
AdaptiveAvgPool2d-98 [1, 480, 1, 1]
Conv2d-99 [1, 20, 1, 1] [20, 480, 1, 1] 9620 (9600 + 20) True True
SiLU-100 [1, 20, 1, 1]
Conv2d-101 [1, 480, 1, 1] [480, 20, 1, 1] 10080 (9600 + 480) True True
Sigmoid-102 [1, 480, 1, 1]
Conv2d-103 [1, 80, 14, 14] [80, 480, 1, 1] 38400 (38400+0) True
BatchNorm2d-104 [1, 80, 14, 14] [80] 160 (80 + 80) True True
Conv2d-105 [1, 480, 14, 14] [480, 80, 1, 1] 38400 (38400+0) True
BatchNorm2d-106 [1, 480, 14, 14] [480] 960 (480 + 480) True True
SiLU-107 [1, 480, 14, 14]
Conv2d-108 [1, 480, 14, 14] [480, 1, 5, 5] 12000 (12000+0) True
BatchNorm2d-109 [1, 480, 14, 14] [480] 960 (480 + 480) True True
SiLU-110 [1, 480, 14, 14]
AdaptiveAvgPool2d-111 [1, 480, 1, 1]
Conv2d-112 [1, 20, 1, 1] [20, 480, 1, 1] 9620 (9600 + 20) True True
SiLU-113 [1, 20, 1, 1]
Conv2d-114 [1, 480, 1, 1] [480, 20, 1, 1] 10080 (9600 + 480) True True
Sigmoid-115 [1, 480, 1, 1]
Conv2d-116 [1, 112, 14, 14] [112, 480, 1, 1] 53760 (53760+0) True
BatchNorm2d-117 [1, 112, 14, 14] [112] 224 (112 + 112) True True
Conv2d-118 [1, 672, 14, 14] [672, 112, 1, 1] 75264 (75264+0) True
BatchNorm2d-119 [1, 672, 14, 14] [672] 1344 (672 + 672) True True
SiLU-120 [1, 672, 14, 14]
Conv2d-121 [1, 672, 14, 14] [672, 1, 5, 5] 16800 (16800+0) True
BatchNorm2d-122 [1, 672, 14, 14] [672] 1344 (672 + 672) True True
SiLU-123 [1, 672, 14, 14]
AdaptiveAvgPool2d-124 [1, 672, 1, 1]
Conv2d-125 [1, 28, 1, 1] [28, 672, 1, 1] 18844 (18816 + 28) True True
SiLU-126 [1, 28, 1, 1]
Conv2d-127 [1, 672, 1, 1] [672, 28, 1, 1] 19488 (18816 + 672) True True
Sigmoid-128 [1, 672, 1, 1]
Conv2d-129 [1, 112, 14, 14] [112, 672, 1, 1] 75264 (75264+0) True
BatchNorm2d-130 [1, 112, 14, 14] [112] 224 (112 + 112) True True
Conv2d-131 [1, 672, 14, 14] [672, 112, 1, 1] 75264 (75264+0) True
BatchNorm2d-132 [1, 672, 14, 14] [672] 1344 (672 + 672) True True
SiLU-133 [1, 672, 14, 14]
Conv2d-134 [1, 672, 14, 14] [672, 1, 5, 5] 16800 (16800+0) True
BatchNorm2d-135 [1, 672, 14, 14] [672] 1344 (672 + 672) True True
SiLU-136 [1, 672, 14, 14]
AdaptiveAvgPool2d-137 [1, 672, 1, 1]
Conv2d-138 [1, 28, 1, 1] [28, 672, 1, 1] 18844 (18816 + 28) True True
SiLU-139 [1, 28, 1, 1]
Conv2d-140 [1, 672, 1, 1] [672, 28, 1, 1] 19488 (18816 + 672) True True
Sigmoid-141 [1, 672, 1, 1]
Conv2d-142 [1, 112, 14, 14] [112, 672, 1, 1] 75264 (75264+0) True
BatchNorm2d-143 [1, 112, 14, 14] [112] 224 (112 + 112) True True
Conv2d-144 [1, 672, 14, 14] [672, 112, 1, 1] 75264 (75264+0) True
BatchNorm2d-145 [1, 672, 14, 14] [672] 1344 (672 + 672) True True
SiLU-146 [1, 672, 14, 14]
Conv2d-147 [1, 672, 7, 7] [672, 1, 5, 5] 16800 (16800+0) True
BatchNorm2d-148 [1, 672, 7, 7] [672] 1344 (672 + 672) True True
SiLU-149 [1, 672, 7, 7]
AdaptiveAvgPool2d-150 [1, 672, 1, 1]
Conv2d-151 [1, 28, 1, 1] [28, 672, 1, 1] 18844 (18816 + 28) True True
SiLU-152 [1, 28, 1, 1]
Conv2d-153 [1, 672, 1, 1] [672, 28, 1, 1] 19488 (18816 + 672) True True
Sigmoid-154 [1, 672, 1, 1]
Conv2d-155 [1, 192, 7, 7] [192, 672, 1, 1] 129024 (129024+0) True
BatchNorm2d-156 [1, 192, 7, 7] [192] 384 (192 + 192) True True
Conv2d-157 [1, 1152, 7, 7] [1152, 192, 1, 1] 221184 (221184+0) True
BatchNorm2d-158 [1, 1152, 7, 7] [1152] 2304 (1152 + 1152) True True
SiLU-159 [1, 1152, 7, 7]
Conv2d-160 [1, 1152, 7, 7] [1152, 1, 5, 5] 28800 (28800+0) True
BatchNorm2d-161 [1, 1152, 7, 7] [1152] 2304 (1152 + 1152) True True
SiLU-162 [1, 1152, 7, 7]
AdaptiveAvgPool2d-163 [1, 1152, 1, 1]
Conv2d-164 [1, 48, 1, 1] [48, 1152, 1, 1] 55344 (55296 + 48) True True
SiLU-165 [1, 48, 1, 1]
Conv2d-166 [1, 1152, 1, 1] [1152, 48, 1, 1] 56448 (55296 + 1152) True True
Sigmoid-167 [1, 1152, 1, 1]
Conv2d-168 [1, 192, 7, 7] [192, 1152, 1, 1] 221184 (221184+0) True
BatchNorm2d-169 [1, 192, 7, 7] [192] 384 (192 + 192) True True
Conv2d-170 [1, 1152, 7, 7] [1152, 192, 1, 1] 221184 (221184+0) True
BatchNorm2d-171 [1, 1152, 7, 7] [1152] 2304 (1152 + 1152) True True
SiLU-172 [1, 1152, 7, 7]
Conv2d-173 [1, 1152, 7, 7] [1152, 1, 5, 5] 28800 (28800+0) True
BatchNorm2d-174 [1, 1152, 7, 7] [1152] 2304 (1152 + 1152) True True
SiLU-175 [1, 1152, 7, 7]
AdaptiveAvgPool2d-176 [1, 1152, 1, 1]
Conv2d-177 [1, 48, 1, 1] [48, 1152, 1, 1] 55344 (55296 + 48) True True
SiLU-178 [1, 48, 1, 1]
Conv2d-179 [1, 1152, 1, 1] [1152, 48, 1, 1] 56448 (55296 + 1152) True True
Sigmoid-180 [1, 1152, 1, 1]
Conv2d-181 [1, 192, 7, 7] [192, 1152, 1, 1] 221184 (221184+0) True
BatchNorm2d-182 [1, 192, 7, 7] [192] 384 (192 + 192) True True
Conv2d-183 [1, 1152, 7, 7] [1152, 192, 1, 1] 221184 (221184+0) True
BatchNorm2d-184 [1, 1152, 7, 7] [1152] 2304 (1152 + 1152) True True
SiLU-185 [1, 1152, 7, 7]
Conv2d-186 [1, 1152, 7, 7] [1152, 1, 5, 5] 28800 (28800+0) True
BatchNorm2d-187 [1, 1152, 7, 7] [1152] 2304 (1152 + 1152) True True
SiLU-188 [1, 1152, 7, 7]
AdaptiveAvgPool2d-189 [1, 1152, 1, 1]
Conv2d-190 [1, 48, 1, 1] [48, 1152, 1, 1] 55344 (55296 + 48) True True
SiLU-191 [1, 48, 1, 1]
Conv2d-192 [1, 1152, 1, 1] [1152, 48, 1, 1] 56448 (55296 + 1152) True True
Sigmoid-193 [1, 1152, 1, 1]
Conv2d-194 [1, 192, 7, 7] [192, 1152, 1, 1] 221184 (221184+0) True
BatchNorm2d-195 [1, 192, 7, 7] [192] 384 (192 + 192) True True
Conv2d-196 [1, 1152, 7, 7] [1152, 192, 1, 1] 221184 (221184+0) True
BatchNorm2d-197 [1, 1152, 7, 7] [1152] 2304 (1152 + 1152) True True
SiLU-198 [1, 1152, 7, 7]
Conv2d-199 [1, 1152, 7, 7] [1152, 1, 3, 3] 10368 (10368+0) True
BatchNorm2d-200 [1, 1152, 7, 7] [1152] 2304 (1152 + 1152) True True
SiLU-201 [1, 1152, 7, 7]
AdaptiveAvgPool2d-202 [1, 1152, 1, 1]
Conv2d-203 [1, 48, 1, 1] [48, 1152, 1, 1] 55344 (55296 + 48) True True
SiLU-204 [1, 48, 1, 1]
Conv2d-205 [1, 1152, 1, 1] [1152, 48, 1, 1] 56448 (55296 + 1152) True True
Sigmoid-206 [1, 1152, 1, 1]
Conv2d-207 [1, 320, 7, 7] [320, 1152, 1, 1] 368640 (368640+0) True
BatchNorm2d-208 [1, 320, 7, 7] [320] 640 (320 + 320) True True
Conv2d-209 [1, 1280, 7, 7] [1280, 320, 1, 1] 409600 (409600+0) True
BatchNorm2d-210 [1, 1280, 7, 7] [1280] 2560 (1280 + 1280) True True
SiLU-211 [1, 1280, 7, 7]
AdaptiveAvgPool2d-212 [1, 1280, 1, 1]
Dropout-213 [1, 1280]
Linear-214 [1, 1000] [1000, 1280] 1281000 (1280000 + 1000) True True
______________________________________________________________________________________________________________________________________________________
Total parameters 5,288,548
Total Non-Trainable parameters 0
Total Trainable parameters 5,288,548