It appears that I cannot attach files here, so I will provide the code samples inline.
The following is an abstract base class for constructing networks; I call this file NetworkBase.py.
import abc
import logging
from collections import OrderedDict
import torch
from torch.autograd import Variable
import pandas as pd
from utils.timing import timeit
class BaseNetwork(metaclass=abc.ABCMeta):
    """Abstract base class for feature-extractor networks.

    Subclasses implement :meth:`definition`, which must return the
    ``torch.nn.Module`` holding the network's layers; the module is built
    once in ``__init__`` and can be profiled with :meth:`summary`.
    """

    def __init__(self, name):
        """Store the display *name* and build the network via definition()."""
        self._name = name
        # definition() is supplied by the subclass; the resulting module is
        # what summary() runs the forward pass through.
        self._features = self.definition()

    @abc.abstractmethod
    def definition(self):
        """Return the ``torch.nn.Module`` that implements the network."""

    def name(self):
        """Return the network's display name."""
        return self._name

    @timeit
    def summary(self, input_size=(3, 224, 224)):
        """Run one forward pass and return a per-layer profile.

        Parameters
        ----------
        input_size : sequence of int, or sequence of such sequences
            Input shape in CHW format (no batch dimension).  A sequence of
            shapes describes a multi-input network.
            (Default is a tuple rather than a list to avoid the
            mutable-default-argument pitfall.)

        Returns
        -------
        pandas.DataFrame
            One row per leaf module with its input/output shapes,
            trainability flag, and parameter count.
        """
        summary = OrderedDict()
        hooks = []
        model = self._features

        def register_hook(module):
            def hook(module, input, output):
                class_name = str(module.__class__).split('.')[-1].split("'")[0]
                module_idx = len(summary)
                m_key = '%s-%i' % (class_name, module_idx + 1)
                summary[m_key] = OrderedDict()
                # The batch dimension is reported as -1: the profile is
                # independent of batch size.
                summary[m_key]['input_shape'] = list(input[0].size())
                summary[m_key]['input_shape'][0] = -1
                summary[m_key]['output_shape'] = list(output.size())
                summary[m_key]['output_shape'][0] = -1
                params = 0
                # FIX: guard against parameters that exist but are None
                # (e.g. Conv2d(..., bias=False)); hasattr alone would let
                # .size() raise AttributeError on None.
                weight = getattr(module, 'weight', None)
                if weight is not None:
                    params += torch.prod(
                        torch.LongTensor(list(weight.size())))
                    summary[m_key]['trainable'] = weight.requires_grad
                else:
                    summary[m_key]['trainable'] = False
                bias = getattr(module, 'bias', None)
                if bias is not None:
                    params += torch.prod(
                        torch.LongTensor(list(bias.size())))
                summary[m_key]['nb_params'] = params

            # Hook only leaf modules: containers and the model itself would
            # double-count their children.  `is not` (identity) is the
            # correct comparison for module objects.
            if not isinstance(module, (torch.nn.Sequential,
                                       torch.nn.ModuleList)) \
                    and module is not model:
                hooks.append(module.register_forward_hook(hook))

        # Check whether the network takes multiple inputs.
        if isinstance(input_size[0], (list, tuple)):
            x = [Variable(torch.rand(1, *in_size)) for in_size in input_size]
        else:
            x = Variable(torch.rand(1, *input_size))

        model.apply(register_hook)
        model(x)  # one forward pass populates `summary` via the hooks
        for h in hooks:  # always detach the hooks again
            h.remove()

        df = pd.DataFrame.from_dict(summary).T
        df.index.names = ['Layer_Name']
        # BUG FIX: the original format string had a single placeholder, so
        # the DataFrame passed as a second argument was silently dropped.
        # Lazy %-style logging arguments also avoid formatting work when
        # DEBUG is disabled.
        logging.debug(
            "A basic profile of the base network appears below for an input "
            "of size %s (CHW format).\n%s",
            ','.join(map(str, input_size)), df)
        return df
The following code then creates an AlexNet model as a subclass (alexnet.py):
import logging
from .NetworkBase import BaseNetwork
import torch.nn as nn
from utils.timing import timeit
class AlexNet(BaseNetwork):
    """AlexNet feature extractor, expressed as a BaseNetwork subclass."""

    def __init__(self):
        logging.debug("The base network has been selected as Alexnet.")
        super().__init__("AlexNet")

    @timeit
    def definition(self):
        """Build and return the AlexNet convolutional trunk.

        The layer stack is identical to the classic AlexNet features:
        five Conv2d/ReLU pairs, with a 3x3 stride-2 max-pool after the
        first, second, and fifth convolutions.
        """
        logging.debug("Defining the network.")
        # (out_channels, kernel_size, stride, padding, pool_after)
        conv_specs = (
            (64, 11, 4, 2, True),
            (192, 5, 1, 2, True),
            (384, 3, 1, 1, False),
            (256, 3, 1, 1, False),
            (256, 3, 1, 1, True),
        )
        layers = []
        in_channels = 3
        for out_channels, kernel, stride, padding, pool_after in conv_specs:
            layers.append(nn.Conv2d(in_channels, out_channels,
                                    kernel_size=kernel, stride=stride,
                                    padding=padding))
            layers.append(nn.ReLU(inplace=True))
            if pool_after:
                layers.append(nn.MaxPool2d(kernel_size=3, stride=2))
            in_channels = out_channels
        return nn.Sequential(*layers)
So one can instantiate AlexNet and then simply call the summary()
method defined in the abstract base class to display the summary.
But if a user defines Inception v3 the way it is defined in the linked torchvision source, this approach will not work, because Inception is not a Sequential model and, moreover, in that source it is also defined through a number of class member variables.
So, hypothetically, what would be a correct approach to this issue if someone wanted to write production-level code?