super().__init__() vs super(classname, self).__init__()

I have a doubt about what exactly super() does.
In the example below, the BertPooler class calls the base class's __init__ through super() with no arguments, while the TypeClassifier class calls it as super(TypeClassifier, self).__init__().

How are the two different, and what is the role of super()?

Thank you.

import torch

class BertPooler(torch.nn.Module):
    def __init__(self, config):
        super().__init__()  # initialise torch.nn.Module first
        self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = torch.nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

class TypeClassifier(torch.nn.Module):
    def __init__(self, model, n_labels):
        """
        In the constructor we instantiate two nn.Linear modules and assign them as
        member variables.
        """
        super(TypeClassifier, self).__init__()
        self.bert_model = model


When you define your classes like this:

class MyClass(ClassToInheritFrom):

you are inheriting from the base class, which in my example is ClassToInheritFrom. In your examples, both classes inherit from torch.nn.Module.

Now, torch.nn.Module has its own methods (like __init__()). If you do NOT define these methods in your own class, they are still available because the parent class has them, but they only do what the base class defines.

However, if you DO write your own __init__() method in your class, it overrides the one from the base class. You can then use super() to access the base class and call its method, so that both the base class's setup and your own setup get done.
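Here is a minimal pure-Python sketch of those three cases. The class names (Base, NoInit, OverrideNoSuper, OverrideWithSuper) and the registry attribute are made up for illustration; registry only loosely stands in for the bookkeeping that torch.nn.Module sets up in its own __init__ (in PyTorch, assigning a submodule before Module.__init__() has run raises an AttributeError).

```python
class Base:
    def __init__(self):
        # Setup done by the base class (a hypothetical stand-in for nn.Module's bookkeeping)
        self.registry = {}


class NoInit(Base):
    # No __init__ defined here, so Base.__init__ is inherited and runs automatically
    pass


class OverrideNoSuper(Base):
    def __init__(self):
        # Overrides Base.__init__ and never calls it, so self.registry is never created
        self.name = "no super"


class OverrideWithSuper(Base):
    def __init__(self):
        super().__init__()        # run Base.__init__ first
        self.name = "with super"  # then do our own setup


print(hasattr(NoInit(), "registry"))             # True
print(hasattr(OverrideNoSuper(), "registry"))    # False
print(hasattr(OverrideWithSuper(), "registry"))  # True
```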

In your case, the two forms do the same thing. In older versions of Python (Python 2) you HAD TO write it like this:

super(TypeClassifier, self).__init__()

But in Python 3 this changed, and you can simply write:

super().__init__()
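A small sketch showing that in Python 3 both spellings reach the same base-class method. All class names here are hypothetical. The one practical difference is that the explicit form hard-codes a class name: writing super(Parent, self) inside a subclass of Parent starts the lookup *after* Parent in the method resolution order, so Parent.__init__ is skipped entirely.

```python
class Base:
    def __init__(self):
        self.calls = ["Base"]


class Explicit(Base):
    def __init__(self):
        super(Explicit, self).__init__()  # Python 2 style, still valid in Python 3
        self.calls.append("Explicit")


class Implicit(Base):
    def __init__(self):
        super().__init__()  # Python 3 shorthand for super(Implicit, self)
        self.calls.append("Implicit")


class Parent(Base):
    def __init__(self):
        super().__init__()
        self.calls.append("Parent")


class SkipsParent(Parent):
    def __init__(self):
        # Naming Parent (not SkipsParent) starts the lookup after Parent in the MRO,
        # so Parent.__init__ is skipped and Base.__init__ runs directly.
        super(Parent, self).__init__()
        self.calls.append("SkipsParent")


print(Explicit().calls)     # ['Base', 'Explicit']
print(Implicit().calls)     # ['Base', 'Implicit']
print(SkipsParent().calls)  # ['Base', 'SkipsParent']
```

So the zero-argument form is not only shorter; it also avoids bugs from naming the wrong class (for example after renaming or copy-pasting a class).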

You can also have multiple inheritance. Here is a small, somewhat silly example:

class A:
    def __init__(self):
        print("This is class A")

class B:
    def __init__(self):
        print("This is class B")

class C(A, B):
    def __init__(self):
        # This looks up the next class in C's method resolution order (A here),
        # so it calls A.__init__
        super().__init__()

        # You can also explicitly call one of the base classes
        B.__init__(self)

        # This is what I wanted to do, extra to the base classes
        print("This is class C")

c = C()
# Output:
#This is class A
#This is class B
#This is class C
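In the example above, super().__init__() inside C reaches only A.__init__, because A.__init__ does not call super() itself, which is why B has to be called explicitly to get all three lines of output. A sketch of the usual alternative, often called cooperative multiple inheritance: every class calls super().__init__(), so a single call in C walks the whole method resolution order (C, A, B, object). These are rewritten versions of the classes for illustration, not code from the original post.

```python
class A:
    def __init__(self):
        print("This is class A")
        super().__init__()  # continue to the next class in the MRO (B, when called from C)


class B:
    def __init__(self):
        print("This is class B")
        super().__init__()  # next in the MRO is object


class C(A, B):
    def __init__(self):
        super().__init__()  # runs A.__init__, which in turn runs B.__init__
        print("This is class C")


print([cls.__name__ for cls in C.__mro__])  # ['C', 'A', 'B', 'object']
c = C()
# This is class A
# This is class B
# This is class C
```

The trade-off is that every class in the hierarchy has to cooperate by calling super(); one class that skips the call cuts the chain off for every class after it in the MRO.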

Maybe this can help you understand more about inheritance ↓

Hope this helps! :)
