How to create a model with shared weights?

(Zhun Zhong) #1

I want to create a model with shared weights. For example, given two inputs A and B, the first 3 NN layers share the same weights, and the next 2 NN layers are separate for A and B respectively.

How can I create such a model so that it performs optimally?

(Adam Paszke) #2

EDIT: we do support sharing Parameters between modules, but it’s recommended to decompose your model into many pieces that don’t share parameters if possible.

We don’t support using the same Parameters in multiple modules. Just reuse the base for both inputs:

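import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        self.base = ...    # shared layers, applied to both inputs
        self.head_A = ...  # layers used only for input1
        self.head_B = ...  # layers used only for input2

    def forward(self, input1, input2):
        # the same self.base is called twice, so its parameters are shared
        return self.head_A(self.base(input1)), self.head_B(self.base(input2))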
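A runnable sketch of the pattern above, assuming simple fully connected layers. The layer sizes (16 → 32 → 8) and the choice of `nn.Linear`/`nn.ReLU` are illustrative assumptions, not from the original post; the structure follows the question: three shared layers, then a separate two-layer head per input.

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self, in_features=16, hidden=32, out_features=8):
        super().__init__()
        # Shared trunk: 3 layers applied to both inputs (sizes are illustrative).
        self.base = nn.Sequential(
            nn.Linear(in_features, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        # Separate 2-layer heads, one per input.
        self.head_A = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, out_features))
        self.head_B = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, out_features))

    def forward(self, input1, input2):
        # self.base is called twice, so both inputs pass through the same weights.
        return self.head_A(self.base(input1)), self.head_B(self.base(input2))


model = MyModel()
a, b = torch.randn(4, 16), torch.randn(4, 16)
out_a, out_b = model(a, b)
print(out_a.shape, out_b.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
```

If both inputs have the same shape, one possible speed-up (an assumption, not discussed in the thread) is to run the base once on `torch.cat([input1, input2])` and split the result. That is equivalent for layers like `Linear` and `ReLU`, but not for `BatchNorm`, whose batch statistics would then mix both inputs.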

(Vladimir) #4

In your example, what happens to the gradients of self.base? Will they be computed taking both input1 and input2 into account?

(Adam Paszke) #5

Yes, you can use the same module multiple times during forward.
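To illustrate the gradient question: when the same module is used twice in one forward pass, autograd sums the gradient contributions from both uses into that module's `.grad`. The check below uses a single hypothetical `nn.Linear` as the shared base and compares the joint gradient against gradients computed from each input separately.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
base = nn.Linear(3, 2)  # hypothetical shared layer
x1, x2 = torch.randn(5, 3), torch.randn(5, 3)

# Joint pass: base is used for both inputs in a single graph.
loss = base(x1).sum() + base(x2).sum()
loss.backward()
joint_grad = base.weight.grad.clone()

# Separate passes: gradient from each input alone.
base.zero_grad()
base(x1).sum().backward()
g1 = base.weight.grad.clone()
base.zero_grad()
base(x2).sum().backward()
g2 = base.weight.grad.clone()

# The joint gradient is the sum of the two per-input contributions.
print(torch.allclose(joint_grad, g1 + g2))  # True
```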

(James Bradbury) #6

There are lots of cases where you can’t just reuse a Module but you still want to share parameters (e.g. in language modeling you sometimes want your word embeddings and output linear layers to share weight matrices). I thought reusing Parameters was ok? It’s used in the PTBLM example and it’s something people will keep doing (and expect to work) unless it throws an error.
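A minimal sketch of the weight tying described above, assuming a toy language model with hypothetical sizes (`TiedLM`, `vocab_size=100`, `dim=32` are illustrative). The output layer reuses the embedding matrix by assigning the same `nn.Parameter` to both modules; this works because `nn.Embedding(vocab, dim).weight` has shape `(vocab, dim)`, which is exactly the weight shape of `nn.Linear(dim, vocab)`.

```python
import torch
import torch.nn as nn


class TiedLM(nn.Module):
    """Toy LM whose embedding and output layer share one weight matrix (hypothetical sizes)."""

    def __init__(self, vocab_size=100, dim=32):
        super().__init__()
        self.encoder = nn.Embedding(vocab_size, dim)  # weight: (vocab_size, dim)
        self.rnn = nn.GRU(dim, dim, batch_first=True)
        self.decoder = nn.Linear(dim, vocab_size)     # weight: (vocab_size, dim)
        # Tie the weights: both modules now hold the same Parameter object.
        self.decoder.weight = self.encoder.weight

    def forward(self, tokens):
        hidden, _ = self.rnn(self.encoder(tokens))
        return self.decoder(hidden)  # logits: (batch, seq, vocab_size)


model = TiedLM()
logits = model(torch.randint(0, 100, (2, 7)))
print(logits.shape)  # torch.Size([2, 7, 100])
print(model.decoder.weight is model.encoder.weight)  # True
```

`model.parameters()` yields the tied matrix only once, since `Module.parameters()` skips duplicates by default, so an optimizer built from it updates the shared weight once per step.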

(Adam Paszke) #7

Yes, they are supported; sorry about that. Still, it’s considered better practice not to do it. I’ve updated the answer.

(김기동) #8

In this code, the author creates the shared modules (‘G_block_share’ / ‘D_block_share’) outside the classes, and then uses these shared modules in two different classes (‘Generator A & B’ or ‘Discriminator A & B’)…

Is this code the right way to share weights between two generators/discriminators?
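A minimal sketch of the pattern described above, assuming a shared block is created once and passed into both models. The `Generator` class, the layer sizes, and the SGD optimizer are hypothetical and not taken from the linked code. Because both generators hold a reference to the same module object, they share its parameters, and gradients from either generator accumulate in it.

```python
import itertools

import torch
import torch.nn as nn

# Shared block created once, outside the classes (hypothetical sizes).
shared_block = nn.Sequential(nn.Linear(8, 16), nn.ReLU())


class Generator(nn.Module):
    def __init__(self, shared):
        super().__init__()
        self.shared = shared         # same object in every Generator built from it
        self.out = nn.Linear(16, 4)  # private to this generator

    def forward(self, x):
        return self.out(self.shared(x))


gen_a, gen_b = Generator(shared_block), Generator(shared_block)
print(gen_a.shared[0].weight is gen_b.shared[0].weight)  # True

# One optimizer over both models: drop the duplicated shared parameters first,
# keeping first-seen order (dicts preserve insertion order).
all_params = itertools.chain(gen_a.parameters(), gen_b.parameters())
unique_params = list({id(p): p for p in all_params}.values())
optimizer = torch.optim.SGD(unique_params, lr=0.1)
```

Passing the same instance is what makes the weights shared; constructing a new block inside each class would give each generator independent weights. The deduplication step is there because PyTorch optimizers are not meant to receive the same parameter twice.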

(N Hunter) #9

Could you please explain why it is better not to do it? Thanks.

(Zhang Heng) #10

Dear Apaszke, thank you for the update! But I am still a little confused by your answer. In your example you have 3 modules (base, head_A, head_B); how could you decompose them into pieces that don’t share parameters? Looking forward to your answer. Thank you for your attention.