nn.Module classes instantiation and weights sharing

I am working with two modules. Each contains different sub-modules, except for one sub-module that must be the same in both, as follows:

import torch.nn as nn

class Module1(nn.Module):
    def __init__(self, net1, net_shared):
        super(Module1,self).__init__()
        self.net1 = net1
        self.net_shared = net_shared
    def forward(self, x):
        ...
        return y

class Module2(nn.Module):
    def __init__(self, net2, net_shared):
        super(Module2,self).__init__()
        self.net2 = net2
        self.net_shared = net_shared
    def forward(self, x):
        ...
        return y

net1 = nn.Sequential()
net2 = nn.Sequential()
net_shared = nn.Sequential()

module1 = Module1(net1, net_shared)
module2 = Module2(net2, net_shared)

net1, net2, and net_shared are pre-instantiated nn.Sequential modules that are passed to the constructors of Module1 and Module2. Will net_shared be treated as a single net, so that when I backpropagate through Module2 the net_shared weights in Module1 are updated as well? Or, because I instantiate two different classes, will it be treated as two deep copies rather than a shallow one?

Hi 151!

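Yes and yes: net_shared is treated as a single net, and training through
module2 updates the same weights that module1 uses.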

No, nothing is being copied (other than the Python references themselves).
The line of code, net_shared = nn.Sequential(), instantiates a single
instance of the class Sequential and sets the Python reference net_shared
to refer to it. When you instantiate Module1 and Module2 and pass in
net_shared, module1.net_shared and module2.net_shared are created as
Python references that refer to the same object instance as the
net_shared reference that is passed in.

There is only one instance of Sequential (not counting net1 and net2),
and it is indeed shared between module1 and module2. No copies, neither
deep nor shallow, are made.
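
The same reference behavior can be seen with plain Python objects, without
PyTorch at all. This is only a sketch: Net and Holder are hypothetical
stand-ins for an nn.Module and for Module1 / Module2, and the weights list
is made up for illustration. It also shows that an independent copy only
appears if you ask for one explicitly, here with copy.deepcopy.

```python
import copy


class Net:
    """Hypothetical stand-in for an nn.Module: just holds a list of weights."""

    def __init__(self, weights):
        self.weights = weights


class Holder:
    """Hypothetical stand-in for Module1 / Module2."""

    def __init__(self, net_shared):
        # Assignment stores a reference to the object; nothing is copied.
        self.net_shared = net_shared


net_shared = Net([1.0, 2.0])
holder1 = Holder(net_shared)
holder2 = Holder(net_shared)

# Both attributes refer to the very same object.
same_object = holder1.net_shared is holder2.net_shared

# A change made through one holder is visible through the other.
holder2.net_shared.weights[0] = 10.0
seen_by_holder1 = holder1.net_shared.weights[0]

# An independent copy has to be requested explicitly.
holder3 = Holder(copy.deepcopy(net_shared))
holder3.net_shared.weights[0] = -1.0
original_after_copy_edit = net_shared.weights[0]
```

If you actually wanted Module1 and Module2 to hold separate weights, passing
copy.deepcopy(net_shared) to one of them would be one way to get that.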

You can check this by comparing the id() of net_shared with the id() of
the object each container holds, and you will see that they are the same:

>>> import torch
>>> torch.__version__
'1.13.0'
>>> net_shared = torch.nn.Sequential()
>>> id (net_shared)
2253492592912
>>> module1 = torch.nn.Sequential (net_shared)
>>> module2 = torch.nn.ModuleList ([net_shared])
>>> id (module1[0])
2253492592912
>>> id (module2[0])
2253492592912
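
To see the effect on gradients with the Module1 / Module2 classes from the
question, something like the following should work. It is a sketch under
assumptions: the original forward() bodies are elided, so the forward()
below (sub-net first, then net_shared) and the Linear layer sizes are
hypothetical choices made only so the example runs.

```python
import itertools

import torch
import torch.nn as nn


class Module1(nn.Module):
    def __init__(self, net1, net_shared):
        super().__init__()
        self.net1 = net1
        self.net_shared = net_shared

    def forward(self, x):
        # Hypothetical forward pass; the original post elides it.
        return self.net_shared(self.net1(x))


class Module2(nn.Module):
    def __init__(self, net2, net_shared):
        super().__init__()
        self.net2 = net2
        self.net_shared = net_shared

    def forward(self, x):
        # Hypothetical forward pass; the original post elides it.
        return self.net_shared(self.net2(x))


# Layer sizes are arbitrary, chosen only for illustration.
net1 = nn.Sequential(nn.Linear(4, 4))
net2 = nn.Sequential(nn.Linear(4, 4))
net_shared = nn.Sequential(nn.Linear(4, 2))

module1 = Module1(net1, net_shared)
module2 = Module2(net2, net_shared)

# One shared object, reachable through both modules.
same_object = module1.net_shared is module2.net_shared

# No backward pass has run yet, so there is no gradient.
grad_before = module1.net_shared[0].weight.grad

# Backpropagate through module2 only.
module2(torch.randn(3, 4)).sum().backward()

# The gradient shows up when the shared weight is read through module1 ...
grad_seen_via_module1 = module1.net_shared[0].weight.grad
# ... while net1, which module2 never used, still has none.
net1_grad = module1.net1[0].weight.grad

# When collecting parameters from both modules (for example, for a single
# optimizer), the shared ones appear twice, so de-duplicate by identity.
all_params = list(itertools.chain(module1.parameters(), module2.parameters()))
unique_params = list({id(p): p for p in all_params}.values())
```

Here all_params has 8 entries but unique_params has 6, because the weight
and bias of the shared Linear layer are reported by both modules.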

Best.

K. Frank


Thanks KFrank for the very clear explanation!