Confusion about nested modules and shared parameters

I am confused about the difference between assigning a module as another module's class member and reusing that module. Suppose I want to share a word embedding between two networks; I usually write code like the following:


class Model(nn.Module):
    def __init__():
        self.embedding = nn.Embedding(10000, 200)
        self.net_a = SubModule(self.embedding)
        self.net_b = SubModule(self.embedding)

   def forward(input):
       return self.net_a(input) + self.net_b(input)

class SubModule(nn.Module):
    def __init___(embedding):
        self.embedding = embedding
        self.fc = nn.Linear(200, 200)

   def forward(self, input):
        return self.fc(self.embedding(input))

In the above code, is the embedding correctly shared between net_a and net_b? It seems that the constructor of SubModule creates another copy of the embedding, so that there are three embedding matrices stored in the model?


Yes, it's a correct way of sharing the module.

After correcting some obvious bugs in your code (the missing superclass __init__() calls, indentation, etc.), the working code is below. You can use model.named_parameters() to check the learnable parameters:

import torch
import torch.nn as nn


class SubModule(nn.Module):
    def __init__(self, embedding):
        super(SubModule, self).__init__()
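        # keep a reference to the shared embedding; no copy is made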
        self.embedding = embedding
        self.fc = nn.Linear(200, 200)

    def forward(self, input):
        return self.fc(self.embedding(input))


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.embedding = nn.Embedding(10000, 200)
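        # a single embedding instance, passed to (and shared by) both submodules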
        self.net_a = SubModule(self.embedding)
        self.net_b = SubModule(self.embedding)

    def forward(self, input):
        return self.net_a(input) + self.net_b(input)

m = Model()
for n, p in m.named_parameters():
    print(n, p.shape)

output:

('embedding.weight', (10000, 200))
('net_a.fc.weight', (200, 200))
('net_a.fc.bias', (200,))
('net_b.fc.weight', (200, 200))
('net_b.fc.bias', (200,))


Thanks! However, if you print the keys in m.state_dict(), you will find:

['embedding.weight',
 'net_a.embedding.weight',
 'net_a.fc.weight',
 'net_a.fc.bias',
 'net_b.embedding.weight',
 'net_b.fc.weight',
 'net_b.fc.bias']

Are all the embedding weights above tied?

Good point!
I just verified that PyTorch handles this so that the weights are still shared: changing the embedding weight through one submodule (e.g., net_a) affects the main embedding weight and the other submodule (e.g., net_b).
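
A minimal check of that, reusing the Model class (and imports) from the snippet above:

m = Model()
sd = m.state_dict()

# all three state_dict entries point at the same underlying storage
print(sd['embedding.weight'].data_ptr() == sd['net_a.embedding.weight'].data_ptr())  # True
print(sd['embedding.weight'].data_ptr() == sd['net_b.embedding.weight'].data_ptr())  # True

# an in-place change made through one submodule is visible everywhere
with torch.no_grad():
    m.net_a.embedding.weight.zero_()
print(m.embedding.weight.abs().sum())        # tensor(0.)
print(m.net_b.embedding.weight.abs().sum())  # tensor(0.)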


Hi, nice idea.
I had a similar question: does this kind of embedding sharing between networks also apply to sharing nn.Sequential modules? That is, can nn.Sequential()-based 'lambdas' (with their weights inside the nn.Sequential()) be shared across different objects of a class?

For example, instead of self.embedding, what if I wanted something like this:

import torch
import torch.nn as nn

class SubModule(nn.Module): 

    def __init__(self, net_s):
        super(SubModule, self).__init__()
        self.nets = lambda : nn.Sequential(nn.Linear(100,200), nn.ReLU())

    def forward(self, input):
        return self.nets(input)

class Model(nn.Module):

    def __init__(self, x, net_s):
        super(Model, self).__init__()
        self.nets1 = SubModule(net_s)
        self.nets2 = SubModule(net_s)

    def forward(self, input):
        return self.nets1(input) + self.nets2(input)


Do nets1 and nets2 use the same Sequential lambda defined in SubModule (with the same weights)?

Or am I missing something?
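
For comparison, here is a minimal sketch (with hypothetical names, mirroring the shared-embedding pattern above) in which a single nn.Sequential instance is created once and passed to both submodules, so that they reuse the same weights:

import torch
import torch.nn as nn

class SubModule(nn.Module):
    def __init__(self, net_s):
        super(SubModule, self).__init__()
        # keep a reference to the shared nn.Sequential instance
        self.nets = net_s

    def forward(self, input):
        return self.nets(input)

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        shared = nn.Sequential(nn.Linear(100, 200), nn.ReLU())
        self.nets1 = SubModule(shared)
        self.nets2 = SubModule(shared)

    def forward(self, input):
        return self.nets1(input) + self.nets2(input)

m = Model()
# the shared Linear's weight and bias are reported only once
for n, p in m.named_parameters():
    print(n, p.shape)

state_dict() will still list the weight under both nets1.nets.0.weight and nets2.nets.0.weight, but the entries refer to the same tensor, as in the embedding example above.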