How do the following two classes interact?

I was following the SNLI classifier example from the official PyTorch examples and found the following two classes:

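```python
import torch.nn as nn


class Bottle(nn.Module):
    # Mixin: merges the first two dimensions, defers to the next class in the MRO, then restores them.

    def forward(self, input):
        if len(input.size()) <= 2:
            # 1D/2D input: nothing to flatten, defer directly.
            return super(Bottle, self).forward(input)
        size = input.size()[:2]
        out = super(Bottle, self).forward(input.view(size[0] * size[1], -1))
        return out.view(*size, -1)


class Linear(Bottle, nn.Linear):
    pass
```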

I don't understand the flow of execution when I use an instance of Linear. How is Linear associated with Bottle? I understand that Linear inherits from nn.Linear, but what is the purpose of the first base class in the declaration `class Linear(Bottle, nn.Linear)`?

Can anyone share some insight?


The Bottle layer is a custom layer designed to apply a Linear layer to 3D tensors by flattening the batch and sequence-length dimensions (or, generically, the first two dimensions) into one, then restoring them on the output.
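A minimal sketch of that shape flow, assuming a (seq_len, batch, features) layout; the sizes 5, 3, 4 and 2 are arbitrary illustration values. It reuses the two classes from the question and compares the result with a manual flatten/apply/unflatten using the same weights.

```python
import torch
import torch.nn as nn


class Bottle(nn.Module):
    def forward(self, input):
        if len(input.size()) <= 2:
            return super(Bottle, self).forward(input)
        size = input.size()[:2]
        out = super(Bottle, self).forward(input.view(size[0] * size[1], -1))
        return out.view(*size, -1)


class Linear(Bottle, nn.Linear):
    pass


torch.manual_seed(0)
layer = Linear(4, 2)          # in_features=4, out_features=2, set up by nn.Linear.__init__
x = torch.randn(5, 3, 4)      # assumed layout: (seq_len, batch, features)

out = layer(x)                # Bottle.forward: (5, 3, 4) -> (15, 4) -> (15, 2) -> (5, 3, 2)
print(out.shape)              # torch.Size([5, 3, 2])

# The same computation by hand: flatten, apply the affine map, unflatten.
manual = nn.functional.linear(x.view(15, 4), layer.weight, layer.bias).view(5, 3, 2)
print(torch.allclose(out, manual))  # True
```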

The new Linear layer they define inherits from both Bottle and nn.Linear. This means Linear gets `__init__` and its other methods from nn.Linear; only `forward` is redefined, and it comes from Bottle.
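One way to confirm which methods Linear actually picks up is to look them up on the class. This sketch repeats the two classes from the question so it runs on its own.

```python
import torch.nn as nn


class Bottle(nn.Module):
    def forward(self, input):
        if len(input.size()) <= 2:
            return super(Bottle, self).forward(input)
        size = input.size()[:2]
        out = super(Bottle, self).forward(input.view(size[0] * size[1], -1))
        return out.view(*size, -1)


class Linear(Bottle, nn.Linear):
    pass


# Bottle defines no __init__, so construction falls through to nn.Linear.__init__.
print(Linear.__init__ is nn.Linear.__init__)   # True
# forward is found on Bottle first, so it shadows nn.Linear.forward.
print(Linear.forward is Bottle.forward)        # True

layer = Linear(4, 2)
print(layer.weight.shape)                      # torch.Size([2, 4]), created by nn.Linear.__init__
```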

@jekbradbury, do you think it would be better to keep such complexities (multiple inheritance, etc.) out of intro examples?


@pranav
I guess this is more of a Python question, but doesn't Bottle inherit from nn.Module alone?
`Bottle.__bases__` returns `(<class 'torch.nn.modules.module.Module'>,)`. So how does `super(Bottle, self).forward(...)` call Linear's forward (which in turn falls back to nn.Linear's forward)?

Shouldn't it call nn.Module's forward, which should raise a NotImplementedError?
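That expectation does hold when the instance really is a plain Bottle: its MRO is only Bottle, nn.Module, object, so the super call lands on nn.Module's placeholder forward. A small check, again repeating the class from the question:

```python
import torch
import torch.nn as nn


class Bottle(nn.Module):
    def forward(self, input):
        if len(input.size()) <= 2:
            return super(Bottle, self).forward(input)
        size = input.size()[:2]
        out = super(Bottle, self).forward(input.view(size[0] * size[1], -1))
        return out.view(*size, -1)


print(Bottle.__mro__)  # Bottle, then nn.Module, then object

plain = Bottle()
try:
    plain(torch.randn(5, 3, 4))
    raised = False
except NotImplementedError:
    # super(Bottle, self) resolved to nn.Module, which has no real forward.
    raised = True
print(raised)  # True
```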


In the model that they subsequently write, they use the Linear layer that they defined through multiple inheritance and not the Bottle layer.


This is the standard Python way to implement a "mixin", although it's a little confusing for people used to Java/C++ inheritance. Python's `super` function is poorly named; it should really be something like `next_in_method_resolution_order`. `super(Bottle, self)` searches the classes that come after Bottle in `type(self).__mro__`, not in `Bottle.__bases__`. If you run `help(Linear)` at the interactive prompt, it will tell you that its MRO (method resolution order) is Linear, Bottle, nn.Linear, nn.Module, object. So calling `Linear.forward` calls `Bottle.forward`, and the super call inside that method calls `nn.Linear.forward`.
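The lookup described above can be checked directly. This sketch prints Linear's MRO and shows that, for a Linear instance, `super(Bottle, self).forward` resolves to `nn.Linear.forward`.

```python
import torch.nn as nn


class Bottle(nn.Module):
    def forward(self, input):
        if len(input.size()) <= 2:
            return super(Bottle, self).forward(input)
        size = input.size()[:2]
        out = super(Bottle, self).forward(input.view(size[0] * size[1], -1))
        return out.view(*size, -1)


class Linear(Bottle, nn.Linear):
    pass


# The MRO is computed for the instance's class (Linear), not for Bottle.
names = [cls.__name__ for cls in Linear.__mro__]
print(names)  # ['Linear', 'Bottle', 'Linear', 'Module', 'object']; the second 'Linear' is nn.Linear

layer = Linear(4, 2)
# super(Bottle, layer) starts searching *after* Bottle in type(layer).__mro__, i.e. at nn.Linear.
print(super(Bottle, layer).forward.__func__ is nn.Linear.forward)  # True
```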


I believe it's best not to have such concepts at all: just simple Python, letting PyTorch and the model take center stage.
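For comparison, a sketch without the mixin, assuming a recent PyTorch release in which `nn.Linear` documents its input shape as `(*, in_features)` and applies the layer over the last dimension. Under that assumption a plain `nn.Linear` already handles the 3D case that Bottle was written for; the sizes are arbitrary illustration values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
proj = nn.Linear(4, 2)       # plain nn.Linear, no Bottle mixin
x = torch.randn(5, 3, 4)     # assumed (seq_len, batch, features) layout

out = proj(x)                # applied over the last dimension
print(out.shape)             # torch.Size([5, 3, 2])

# Explicit flatten/unflatten (what Bottle does), for comparison.
# A small atol allows for different kernels on the 2D and 3D paths.
flat = proj(x.reshape(-1, 4)).reshape(5, 3, 2)
print(torch.allclose(out, flat, atol=1e-6))  # True
```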
