Model summary in PyTorch

I am trying to get something in PyTorch similar to Keras's `model.summary()`. I am using the following code, but I also need to obtain the summary:

class MaskCnnModel(ImageClassificationBase):
    """CNN classifier producing 2 class logits from 3-channel images.

    The first ``nn.Linear`` expects ``512 * 16 * 16`` features. The three 2x2
    max-pools therefore have to reduce the input to 16x16, which means input
    images must be 128x128.

    NOTE(review): assumes ``ImageClassificationBase`` derives from
    ``nn.Module`` (``forward``, ``parameters``, ``train``/``eval`` are used).
    """

    def __init__(self):
        # Must be ``__init__``, not ``init``. PyTorch never calls a method
        # named ``init``, so ``self.network`` would never be created.
        # The parent initializer has to run before any submodule is assigned.
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # output: 256 x 64 x 64

            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # output: 384 x 32 x 32

            nn.Conv2d(384, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # output: 512 x 16 x 16

            # Flatten each sample's 512 x 16 x 16 feature map into a vector
            # of 131072 values, then reduce it through 2048 and 1024 hidden
            # units down to the 2 output classes.
            nn.Flatten(),
            nn.Linear(512 * 16 * 16, 2048),
            nn.ReLU(),
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2),
        )

    def forward(self, xb):
        """Return the class logits for the batch ``xb`` (shape (N, 3, 128, 128))."""
        return self.network(xb)

    def summary(self, input_size=(3, 128, 128)):
        """Print and return a Keras-style summary of ``self.network``.

        A single all-zeros sample of shape ``(1, *input_size)`` is pushed
        through each layer of the ``nn.Sequential`` in turn. The output
        shape (batch dimension shown as -1) and parameter count of every
        layer are recorded.

        Args:
            input_size: Per-sample input shape ``(C, H, W)``. The default
                matches the 128x128 size required by the first Linear layer.

        Returns:
            The summary table as a string (it is also printed).
        """
        import torch

        device = next(self.parameters()).device
        x = torch.zeros(1, *input_size, device=device)

        header = f"{'Layer (type)':<28}{'Output Shape':<24}{'Param #':>12}"
        rule = "=" * len(header)
        rows = [rule, header, rule]
        total = 0
        trainable = 0

        # Temporarily switch to eval mode and restore the caller's mode after.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for idx, layer in enumerate(self.network):
                    x = layer(x)
                    params = list(layer.parameters())
                    n_params = sum(p.numel() for p in params)
                    n_train = sum(p.numel() for p in params if p.requires_grad)
                    total += n_params
                    trainable += n_train
                    name = f"{type(layer).__name__}-{idx}"
                    shape = str([-1, *x.shape[1:]])
                    rows.append(f"{name:<28}{shape:<24}{n_params:>12,}")
        finally:
            self.train(was_training)

        rows += [
            rule,
            f"Total params: {total:,}",
            f"Trainable params: {trainable:,}",
            f"Non-trainable params: {total - trainable:,}",
            rule,
        ]
        text = "\n".join(rows)
        print(text)
        return text

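You could have a look at, e.g., the `torchinfo` package: `from torchinfo import summary; summary(model, input_size=(1, 3, 128, 128))` prints a Keras-style table of layers, output shapes and parameter counts.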

1 Like