Correct way to provide a general template to a model in PyTorch

I was working on a project and I had this question in mind out of curiosity. It’d be great if some light could be shed on this.

So, I was thinking of creating an abstract base class for defining networks using PyTorch for use in a specific system. So, users could sub-class this base class and define their models.

I am interested in computing, for each model, the name of every layer and its number of trainable parameters.
Obviously, I could add an abstractmethod so that each user has the flexibility to define their own network, along with the responsibility of computing the parameter counts and layer names for it.

However, is there a design pattern or approach I could follow so that, instead of putting that burden on the user, they are only responsible for defining the network structure? I could then write a general function that they call directly to get the name of each layer and its number of parameters.

For example, Inception and AlexNet are defined in very different ways. I could write such a function for a Sequential model, but it would not work directly for Inception, would it?

Is there a correct approach?
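One possible direction, sketched under the assumption that every user model is an `nn.Module`: `named_parameters()` already walks the whole module tree, so a single helper can report layer names and trainable parameter counts whether or not the model is an `nn.Sequential`. The helper name `count_parameters` and the toy `Branchy` model are made up for illustration.

```python
from collections import OrderedDict

import torch
from torch import nn


def count_parameters(model: nn.Module) -> dict:
    """Map each parameter-owning submodule name to its trainable parameter count."""
    counts = OrderedDict()
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # "branch_a.0.weight" -> owning module "branch_a.0"
        owner = name.rsplit(".", 1)[0] if "." in name else ""
        counts[owner] = counts.get(owner, 0) + param.numel()
    return counts


class Branchy(nn.Module):
    """Hypothetical non-sequential model: two parallel branches, concatenated."""

    def __init__(self):
        super().__init__()
        self.branch_a = nn.Sequential(nn.Conv2d(3, 8, kernel_size=1), nn.ReLU())
        self.branch_b = nn.Conv2d(3, 4, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.cat([self.branch_a(x), self.branch_b(x)], dim=1)


counts = count_parameters(Branchy())
for layer, n in counts.items():
    print(f"{layer}: {n}")
```

The helper never looks at `forward()`, so the wiring of the network does not matter to it. It reports only modules that own parameters, so activation and pooling layers do not appear.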

I would look at other github projects on how folks have done it. You could start with writing your model class and functions very specific for your problem and then abstract away elements where you could re-use. If you have some code to share, you might get better suggestions on your next step.

It appears that I cannot attach files here, so I am going to provide the code samples inline.

The following is an Abstract Base Class for constructing networks. I call this file as

import abc
import logging
from collections import OrderedDict
import torch
from torch.autograd import Variable
import pandas as pd
from utils.timing import timeit

class BaseNetwork(metaclass=abc.ABCMeta):
    def __init__(self, name):
        self._name = name
        self._features = self.definition()

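    @abc.abstractmethod
    def definition(self):
        """Build and return the network; every subclass must implement this."""
        raise NotImplementedError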

    def name(self):
        return self._name

    def summary(self, input_size=[3, 224, 224]):
        def register_hook(module):
            def hook(module, input, output):
                class_name = str(module.__class__).split('.')[-1].split("'")[0]
                module_idx = len(summary)

                m_key = '%s-%i' % (class_name, module_idx + 1)
                summary[m_key] = OrderedDict()
                summary[m_key]['input_shape'] = list(input[0].size())
                summary[m_key]['input_shape'][0] = -1
                summary[m_key]['output_shape'] = list(output.size())
                summary[m_key]['output_shape'][0] = -1

                # Count the elements of weight and bias, when the module has them.
                params = 0
                if hasattr(module, 'weight') and module.weight is not None:
                    params += module.weight.numel()
                    summary[m_key]['trainable'] = module.weight.requires_grad
                else:
                    summary[m_key]['trainable'] = False
                if hasattr(module, 'bias') and module.bias is not None:
                    params += module.bias.numel()
                summary[m_key]['nb_params'] = params

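            # Hook every module except containers and the top-level model.
            if not isinstance(module, torch.nn.Sequential) and \
                    not isinstance(module, torch.nn.ModuleList) and \
                    not (module == model):
                hooks.append(module.register_forward_hook(hook))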

        # check if there are multiple inputs to the network
        if isinstance(input_size[0], (list, tuple)):
            x = [Variable(torch.rand(1, *in_size)) for in_size in input_size]
        else:
            x = Variable(torch.rand(1, *input_size))

        # create properties
        summary = OrderedDict()
        hooks = []
        model = self._features
        # register hook
        model.apply(register_hook)
        # make a forward pass
        if isinstance(x, list):
            model(*x)
        else:
            model(x)
        # remove these hooks
        for h in hooks:
            h.remove()

        df = pd.DataFrame.from_dict(summary)
        df = df.T
        df.index.names = ['Layer_Name']
        logging.debug("A basic profile of the base network appears below for "
                      "an input of size {} (CHW format).\n{}".format(
                          ','.join(map(str, input_size)), df))
        return df

Then the following code is that of creating an AlexNet model as a subclass (

import logging
from .NetworkBase import BaseNetwork
import torch.nn as nn
from utils.timing import  timeit

class AlexNet(BaseNetwork):
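    def __init__(self):
        logging.debug("The base network has been selected as Alexnet.")
        # BaseNetwork.__init__ requires a name; the value "AlexNet" is assumed here.
        super().__init__(name="AlexNet")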

    def definition(self):
        logging.debug("Defining the network.")
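        # Convolutional feature extractor; the ReLU activations follow the
        # standard AlexNet layout.
        model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        return model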

So, one can define the AlexNet and then simply call the summary() method defined in the abstract base class to display the summary.

But if a user defines Inception v3 as it is defined here, this approach won’t work: Inception is not a Sequential model and, moreover, in the aforementioned link it is defined through a number of class member variables.

So, what would be a correct approach to this issue if (just imagining!) someone wanted to write production-level code?
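A sketch of one possible direction, not a definitive design: make the base class itself an `nn.Module` and let subclasses define `__init__` and `forward()` as usual. Forward hooks registered on the leaf modules then record output shapes in execution order, however the network is wired. The names `SummaryNetwork` and `TwoBranch` are hypothetical, and the sketch assumes a single tensor input and that every leaf module returns a single tensor.

```python
from collections import OrderedDict

import torch
from torch import nn


class SummaryNetwork(nn.Module):
    """Hypothetical base class: subclasses only define layers and forward()."""

    def summary(self, input_size=(3, 224, 224)):
        rows = OrderedDict()
        hooks = []

        def hook(module, inputs, output):
            key = f"{type(module).__name__}-{len(rows) + 1}"
            rows[key] = {
                "output_shape": [-1] + list(output.shape[1:]),
                "nb_params": sum(p.numel() for p in module.parameters(recurse=False)),
            }

        for module in self.modules():
            # Leaf modules only: containers and the model itself have children.
            if len(list(module.children())) == 0:
                hooks.append(module.register_forward_hook(hook))

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                self(torch.rand(1, *input_size))
        finally:
            # Always detach the hooks and restore the previous mode.
            for h in hooks:
                h.remove()
            self.train(was_training)
        return rows


class TwoBranch(SummaryNetwork):
    """Hypothetical Inception-like block with two parallel paths."""

    def __init__(self):
        super().__init__()
        self.path1 = nn.Conv2d(3, 8, kernel_size=1)
        self.path3 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = torch.cat([self.path1(x), self.path3(x)], dim=1)
        return self.fc(self.pool(x).flatten(1))


rows = TwoBranch().summary(input_size=(3, 32, 32))
for name, info in rows.items():
    print(name, info)
```

Because the hooks fire during the forward pass, a module that is called twice is listed twice, and operations written as plain functions inside `forward()` (such as `torch.cat` here) are not listed at all.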

This does not work with class attributes.

import torch
from torch import nn as nn
from torch_utils import View, Λ

class Metric(nn.Module):
    latent_dim = None
    embed = None
    kernel = None

# regular conv, migrated from ConvLargeL2.
class Conv(Metric):
    def __init__(self, input_dim, latent_dim):
        # nn.Module must be initialised before submodules can be assigned.
        super().__init__()
        self.latent_dim = latent_dim
        self.embed = nn.Sequential(
            nn.Conv2d(input_dim, 32, kernel_size=4, stride=2),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.Conv2d(64, 64, kernel_size=4, stride=2),
            nn.Conv2d(64, 32, kernel_size=4, stride=2),
            nn.Linear(128, latent_dim),
        )
        self.kernel = Λ(lambda a, b: (a - b).norm(2, dim=-1))

    def forward(self, x, x_prime):
        *b, C, H, W = x.shape
        *b_, C, H, W = x_prime.shape
        z_1, z_2 = torch.broadcast_tensors(
            self.embed(x.reshape(-1, C, H, W)).reshape(*b, self.latent_dim),
            self.embed(x_prime.reshape(-1, C, H, W)).reshape(*b_, self.latent_dim))
        *b, W = z_1.shape
        return self.kernel(z_1, z_2).reshape(*b, 1)

net = Conv(...)

gives ‘None’.
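If the `None` comes from reading an attribute such as `net.embed`, one plausible explanation lies in how `nn.Module` stores submodules: an assigned module goes into `self._modules` rather than the instance `__dict__`, and is normally reached through `__getattr__`, which Python only calls when ordinary attribute lookup fails. A class attribute like `embed = None` makes ordinary lookup succeed, so the `None` shadows the registered module. The minimal sketch below reproduces this and shows one workaround, declaring the attribute as a bare annotation instead of assigning `None`. The class names `Shadowed` and `Annotated` are made up.

```python
from torch import nn


class Shadowed(nn.Module):
    embed = None  # class attribute: found before nn.Module.__getattr__ runs

    def __init__(self):
        super().__init__()
        self.embed = nn.Linear(4, 2)  # registered in self._modules


class Annotated(nn.Module):
    embed: nn.Module  # annotation only, creates no class attribute

    def __init__(self):
        super().__init__()
        self.embed = nn.Linear(4, 2)


shadowed = Shadowed()
print(shadowed.embed)               # None
print(shadowed._modules["embed"])   # the Linear layer is still registered

annotated = Annotated()
print(annotated.embed)              # Linear(in_features=4, out_features=2, bias=True)
```

Anything that iterates over the registered modules (`children()`, `apply()`, `parameters()`) still sees the layer in the shadowed case; only direct attribute access, including `self.embed(...)` inside `forward()`, gets the `None`.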