What is the base class that only models inherit from?

I’m working on supporting automatic model detection/logging for PyTorch models for our machine learning platform https://iko.ai and I want to know what the base class is that only models inherit from. In other words, I’m looking for the class X that satisfies this condition: if a class Y inherits from class X, then Y is a PyTorch model.

Is it torch.nn.modules.module.Module ?

The base class is the Module class in torch.nn

import torch
import torch.nn as nn

class NNmodel(nn.Module):
    def __init__(self, *args, **kwargs):  # whatever arguments your model takes
        super(NNmodel, self).__init__()

Are all the classes that inherit from torch.nn.Module models?

Yeah.
They are neural network models.
Any neural network model made with PyTorch always inherits from the Module class.

What about the torch.nn.CrossEntropyLoss class? It inherits from torch.nn.Module but it’s not a model, right?

It’s a loss function

What I meant was that to build a neural network model in PyTorch, you have to inherit from the torch.nn.Module class.

Yeah, but I’m looking for the base class that only models inherit from. Is there such a class? Or any condition that only models fulfill?

The only class that models inherit from in PyTorch is the torch.nn.Module class.

OK, Thanks.

Do you see any other way to distinguish model classes from non-model classes? A common attribute or method?

I don’t know if I understand your question, but if you mean distinguishing the class of an implemented model architecture from one that is not a model, then the only difference is that model classes in PyTorch inherit from torch.nn.Module.

If you are referring to the classes that make up a NN model architecture, like torch.nn.Linear, torch.nn.Conv2d, etc., and are trying to differentiate them from things like loss functions, e.g. torch.nn.CrossEntropyLoss or torch.nn.NLLLoss, then I guess you just differentiate them by their names: from the names you know which can constitute a model architecture and which is a loss function.

I’m referring to your second guess: how can I differentiate between classes that constitute an NN model and those which implement losses …

I want to write a Python function to do that for me, and clearly relying on class names is not a good way to do it.

Anyway, thank you very much for your help.

Hmmmm🤔

You want to write a Python function that differentiates the model constituents’ classes from the loss classes?

Yup. For now the function is something like this:

def is_pytorch_model(obj):
    return isinstance(obj, torch.nn.Module)

It clearly treats losses and other torch.nn.Module subclasses as NN models, which is not true.

Hmmm🤔

Have you tried differentiating them by the kind of values they return?
The output of, let’s say, torch.nn.Conv2d is different from what torch.nn.CrossEntropyLoss outputs, shape-wise and all…

I can’t do that because I’m doing a static analysis of the code using astroid.
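
To make the static-analysis constraint concrete, here is a minimal sketch that finds torch.nn.Module subclasses without running any model code. It uses the standard-library ast module as a stand-in for astroid (an assumption on my part; astroid can actually resolve imports and indirect inheritance, which this cannot). The names SOURCE, MODULE_NAMES and module_subclasses are hypothetical, and the check only matches bases spelled literally as nn.Module or torch.nn.Module. It needs Python 3.9+ for ast.unparse.

```python
import ast

# Hypothetical source text to analyse; nothing in it is executed.
SOURCE = '''
import torch.nn as nn

class Net(nn.Module):
    pass

class Helper:
    pass
'''

# Assumed spellings of the base class in the analysed code.
MODULE_NAMES = {"nn.Module", "torch.nn.Module"}


def module_subclasses(source):
    """Return names of classes whose direct bases are written as nn.Module."""
    tree = ast.parse(source)
    found = []
    for node in ast.walk(tree):
        if isinstance(node, ast.ClassDef):
            # Turn each base expression back into source text for comparison.
            bases = {ast.unparse(base) for base in node.bases}
            if bases & MODULE_NAMES:
                found.append(node.name)
    return found


print(module_subclasses(SOURCE))  # ['Net']
```

This has the same limitation as isinstance(obj, torch.nn.Module): a loss class written as a direct nn.Module subclass would be reported too.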

For losses, there are actually a couple of attributes that could differentiate them from the other classes:


import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    # define the network layers
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(4, 100)
        self.fc2 = nn.Linear(100, 100)
        self.fc3 = nn.Linear(100, 3)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.softmax(x)
        return x

net = Net()
criterion = nn.CrossEntropyLoss()
set(dir(criterion)) - set(dir(net))

>>> {'__constants__', 'ignore_index', 'reduction', 'weight'}

What are all the other things that should inherit from torch.nn.Module?

Well, to my current knowledge, only the model constituents and loss functions inherit from the torch.nn.Module class.

Hello Haroune and Henry!

I don’t think that there is a simple, one-stop-shopping way of doing
this that is completely reliable. Note that activations also inherit
from Module.

Losses, however, (appear to) inherit (possibly indirectly) from
torch.nn.modules.loss._Loss (which, in turn, inherits from
Module). You could look for that.
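
A minimal sketch of that check, assuming the built-in losses you care about do inherit from torch.nn.modules.loss._Loss. Note that _Loss is a private class, so it may move or change between releases, and the helper names is_loss and is_pytorch_model are only illustrative. A user-defined loss that subclasses Module directly would not be caught.

```python
import torch
from torch.nn.modules.loss import _Loss  # private base class; may change between releases


def is_loss(obj):
    """Return True if obj is an instance of a class derived from _Loss."""
    return isinstance(obj, _Loss)


def is_pytorch_model(obj):
    """Heuristic: any Module that is not a _Loss.

    Activations such as Softmax still count as "models" here.
    """
    return isinstance(obj, torch.nn.Module) and not is_loss(obj)


print(is_pytorch_model(torch.nn.Linear(3, 5)))        # True
print(is_pytorch_model(torch.nn.CrossEntropyLoss()))  # False
print(is_pytorch_model(torch.nn.Softmax(dim=1)))      # True
```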

You might consider counting parameters() on the theory that any
self-respecting model has parameters (although some activation
functions also have parameters).
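
The parameter-counting idea could be sketched like this. It is only a heuristic, and the helper names count_parameter_tensors and has_parameters are made up for illustration. It counts parameter tensors, not individual scalar weights, and it needs a live module instance, so it does not help with purely static analysis.

```python
import torch


def count_parameter_tensors(module):
    """Count the parameter tensors registered on a module and its submodules."""
    return sum(1 for _ in module.parameters())


def has_parameters(module):
    """Heuristic: treat a Module with at least one parameter as model-like."""
    return count_parameter_tensors(module) > 0


print(has_parameters(torch.nn.Linear(3, 5)))  # True  (weight and bias)
print(has_parameters(torch.nn.MSELoss()))     # False
print(has_parameters(torch.nn.PReLU()))       # True  (an activation with a learnable slope)
```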

Also, you have to be clear about what you mean by a “model.”
Would you consider a single Linear to be a model? Arguably
it would be. How about a single Softmax()? This seems more
of a stretch, but how does it behave any differently than a
Linear? (One difference is that it doesn’t have any parameters.)

This script illustrates some of these points:

import torch
print (torch.__version__)
model_linear = torch.nn.Linear (3, 5)
model_sequential = torch.nn.Sequential ((torch.nn.Linear (4, 6)))
loss_mse = torch.nn.MSELoss()
loss_ce = torch.nn.CrossEntropyLoss()
act_softmax = torch.nn.Softmax()
act_prelu = torch.nn.PReLU()
print (sum (1 for _ in model_linear.parameters()))       # count parameters
print (sum (1 for _ in model_sequential.parameters()))   # count parameters
print (sum (1 for _ in loss_mse.parameters()))           # count parameters
print (sum (1 for _ in loss_ce.parameters()))            # count parameters
print (sum (1 for _ in act_softmax.parameters()))        # count parameters
print (sum (1 for _ in act_prelu.parameters()))          # count parameters
print (model_linear.__class__.__bases__)                 # immediate superclass
print (model_sequential.__class__.__bases__)             # immediate superclass
print (loss_mse.__class__.__bases__)                     # immediate superclass
print (loss_ce.__class__.__bases__)                      # immediate superclass
print (act_softmax.__class__.__bases__)                  # immediate superclass
print (act_prelu.__class__.__bases__)                    # immediate superclass
print (model_linear.__class__.__mro__)                   # full class hierarchy
print (model_sequential.__class__.__mro__)               # full class hierarchy
print (loss_mse.__class__.__mro__)                       # full class hierarchy
print (loss_ce.__class__.__mro__)                        # full class hierarchy
print (act_softmax.__class__.__mro__)                    # full class hierarchy
print (act_prelu.__class__.__mro__)                      # full class hierarchy

Here is the output:

>>> import torch
>>> print (torch.__version__)
1.6.0
>>> model_linear = torch.nn.Linear (3, 5)
>>> model_sequential = torch.nn.Sequential ((torch.nn.Linear (4, 6)))
>>> loss_mse = torch.nn.MSELoss()
>>> loss_ce = torch.nn.CrossEntropyLoss()
>>> act_softmax = torch.nn.Softmax()
>>> act_prelu = torch.nn.PReLU()
>>> print (sum (1 for _ in model_linear.parameters()))       # count parameters
2
>>> print (sum (1 for _ in model_sequential.parameters()))   # count parameters
2
>>> print (sum (1 for _ in loss_mse.parameters()))           # count parameters
0
>>> print (sum (1 for _ in loss_ce.parameters()))            # count parameters
0
>>> print (sum (1 for _ in act_softmax.parameters()))        # count parameters
0
>>> print (sum (1 for _ in act_prelu.parameters()))          # count parameters
1
>>> print (model_linear.__class__.__bases__)                 # immediate superclass
(<class 'torch.nn.modules.module.Module'>,)
>>> print (model_sequential.__class__.__bases__)             # immediate superclass
(<class 'torch.nn.modules.module.Module'>,)
>>> print (loss_mse.__class__.__bases__)                     # immediate superclass
(<class 'torch.nn.modules.loss._Loss'>,)
>>> print (loss_ce.__class__.__bases__)                      # immediate superclass
(<class 'torch.nn.modules.loss._WeightedLoss'>,)
>>> print (act_softmax.__class__.__bases__)                  # immediate superclass
(<class 'torch.nn.modules.module.Module'>,)
>>> print (act_prelu.__class__.__bases__)                    # immediate superclass
(<class 'torch.nn.modules.module.Module'>,)
>>> print (model_linear.__class__.__mro__)                   # full class hierarchy
(<class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.module.Module'>, <class 'object'>)
>>> print (model_sequential.__class__.__mro__)               # full class hierarchy
(<class 'torch.nn.modules.container.Sequential'>, <class 'torch.nn.modules.module.Module'>, <class 'object'>)
>>> print (loss_mse.__class__.__mro__)                       # full class hierarchy
(<class 'torch.nn.modules.loss.MSELoss'>, <class 'torch.nn.modules.loss._Loss'>, <class 'torch.nn.modules.module.Module'>, <class 'object'>)
>>> print (loss_ce.__class__.__mro__)                        # full class hierarchy
(<class 'torch.nn.modules.loss.CrossEntropyLoss'>, <class 'torch.nn.modules.loss._WeightedLoss'>, <class 'torch.nn.modules.loss._Loss'>, <class 'torch.nn.modules.module.Module'>, <class 'object'>)
>>> print (act_softmax.__class__.__mro__)                    # full class hierarchy
(<class 'torch.nn.modules.activation.Softmax'>, <class 'torch.nn.modules.module.Module'>, <class 'object'>)
>>> print (act_prelu.__class__.__mro__)                      # full class hierarchy
(<class 'torch.nn.modules.activation.PReLU'>, <class 'torch.nn.modules.module.Module'>, <class 'object'>)

Best.

K. Frank
