How do I update the network selectively?

Hi, I am designing an autoencoder with multiple decoders, and I want to update only a selected decoder during training. For example, below is a toy network with a shared encoder and two decoders.

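import torch
import torch.nn as nn

class myModel(nn.Module):
    def __init__(self):
        super(myModel, self).__init__()
        # shared encoder
        self.ENC = nn.ModuleList()
        self.ENC.append(nn.Conv1d(1, 1, kernel_size=2))
        # two alternative decoders
        self.DEC1 = nn.ModuleList()
        self.DEC1.append(nn.Conv1d(1, 1, kernel_size=2))
        self.DEC2 = nn.ModuleList()
        self.DEC2.append(nn.Conv1d(1, 1, kernel_size=2))

    def forward(self, x, decoder_idx):
        for f in self.ENC:
            x = f(x)
        # run only the selected decoder branch
        if decoder_idx == 1:
            for f in self.DEC1:
                x = f(x)
        elif decoder_idx == 2:
            for f in self.DEC2:
                x = f(x)
        return x

model = myModel()
optimizer = torch.optim.Adam(model.parameters())  # one optimizer over ALL parameters
x = torch.randn(8, 1, 16)  # dummy input: (batch, channels, length)

# step through decoder 1
out = model(x, 1)
loss = out.pow(2).mean()  # placeholder loss for illustration
optimizer.zero_grad()
loss.backward()
optimizer.step()

# step through decoder 2
out = model(x, 2)
loss = out.pow(2).mean()  # placeholder loss for illustration
optimizer.zero_grad()
loss.backward()
optimizer.step()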

When I test with out = model(x, 1), only the parameters of ENC and DEC1 are updated, as I expected. However, when I then run out = model(x, 2), ENC, DEC1 and DEC2 are all updated. How can I update only the selected parts when alternating between decoders like this?

This issue might be related to your optimizer.
If you use an optimizer with momentum or running averages of past gradients and pass all parameters to it, your inactive parameters might still be updated.
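A minimal sketch of this effect, using a single hypothetical parameter `p` instead of the model above: after one real Adam step, a second step with an all-zero gradient still moves `p`, because Adam's stored first-moment estimate is non-zero. The zero gradient is kept explicitly with `set_to_none=False`; as far as I know, recent PyTorch releases default `zero_grad()` to `set_to_none=True`, and Adam skips parameters whose `.grad` is `None`, which the second call illustrates.

```python
import torch

def moved_after_zero_grad_step(set_to_none: bool) -> bool:
    """Return True if an Adam step with a cleared gradient still changes p."""
    torch.manual_seed(0)
    p = torch.nn.Parameter(torch.randn(3))  # hypothetical "decoder" parameter
    opt = torch.optim.Adam([p], lr=0.1)

    # Step 1: p is active and receives a non-zero gradient.
    p.sum().backward()
    opt.step()

    # Step 2: p is inactive. Clear its gradient (zeros or None), then step again.
    opt.zero_grad(set_to_none=set_to_none)
    before = p.detach().clone()
    opt.step()
    return not torch.equal(before, p.detach())

print(moved_after_zero_grad_step(set_to_none=False))  # grad is a zero tensor
print(moved_after_zero_grad_step(set_to_none=True))   # grad is None
```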
Which optimizer are you currently using?

I think the safest way would be to create two (or even three) optimizers and pass the corresponding parameters to each of them.
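A rough sketch of the three-optimizer layout, assuming a simplified, hypothetical version of the model from the question (single layers instead of ModuleLists) and a placeholder loss. Only the encoder optimizer and the optimizer of the decoder used in the forward pass are stepped, so the other decoder's Adam state never reaches its weights.

```python
import torch
import torch.nn as nn

class SharedEncoderModel(nn.Module):
    """Hypothetical, simplified version of myModel from the question."""
    def __init__(self):
        super().__init__()
        self.ENC = nn.Conv1d(1, 1, kernel_size=2)
        self.DEC1 = nn.Conv1d(1, 1, kernel_size=2)
        self.DEC2 = nn.Conv1d(1, 1, kernel_size=2)

    def forward(self, x, decoder_idx):
        dec = self.DEC1 if decoder_idx == 1 else self.DEC2
        return dec(self.ENC(x))

model = SharedEncoderModel()
# Three optimizers: one for the shared encoder and one per decoder.
enc_opt = torch.optim.Adam(model.ENC.parameters(), lr=1e-3)
dec1_opt = torch.optim.Adam(model.DEC1.parameters(), lr=1e-3)
dec2_opt = torch.optim.Adam(model.DEC2.parameters(), lr=1e-3)

def train_step(x, decoder_idx):
    dec_opt = dec1_opt if decoder_idx == 1 else dec2_opt
    enc_opt.zero_grad()
    dec_opt.zero_grad()
    loss = model(x, decoder_idx).pow(2).mean()  # placeholder loss for illustration
    loss.backward()
    # Step only the encoder and the decoder that was actually used.
    enc_opt.step()
    dec_opt.step()
    return loss.item()

x = torch.randn(4, 1, 16)  # dummy batch: (batch, channels, length)
train_step(x, 1)
train_step(x, 2)
```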

Thank you for your reply. I am currently using Adam, so yes, that is probably the cause.
Is using multiple optimizers the only option? That would mean 16 optimizers for 16 decoders.
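Sixteen optimizers do not have to be written out by hand. A sketch under the assumption that the decoders are kept in an `nn.ModuleList` (the class name `MultiDecoderModel` and the 0-based `decoder_idx` are hypothetical): the per-decoder optimizers are built in one list comprehension and indexed like the decoders.

```python
import torch
import torch.nn as nn

NUM_DECODERS = 16

class MultiDecoderModel(nn.Module):
    """Hypothetical shared-encoder model with a list of decoders."""
    def __init__(self, num_decoders):
        super().__init__()
        self.ENC = nn.Conv1d(1, 1, kernel_size=2)
        # ModuleList registers every decoder as a submodule.
        self.DECS = nn.ModuleList([nn.Conv1d(1, 1, kernel_size=2) for _ in range(num_decoders)])

    def forward(self, x, decoder_idx):
        return self.DECS[decoder_idx](self.ENC(x))

model = MultiDecoderModel(NUM_DECODERS)
enc_opt = torch.optim.Adam(model.ENC.parameters(), lr=1e-3)
# One Adam instance per decoder, built in a loop.
dec_opts = [torch.optim.Adam(dec.parameters(), lr=1e-3) for dec in model.DECS]

def train_step(x, decoder_idx):
    """decoder_idx is 0-based here: 0 .. NUM_DECODERS - 1."""
    dec_opt = dec_opts[decoder_idx]
    enc_opt.zero_grad()
    dec_opt.zero_grad()
    loss = model(x, decoder_idx).pow(2).mean()  # placeholder loss
    loss.backward()
    enc_opt.step()
    dec_opt.step()
    return loss.item()
```

Since Adam keeps its running averages per parameter, 16 small optimizers should hold roughly the same amount of state as one optimizer over all parameters; the extra cost is mostly bookkeeping.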

The alternative would be to use an optimizer without these properties, e.g. plain SGD without momentum or weight decay.
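A small check of why plain SGD behaves differently, again with a single hypothetical parameter and an explicitly zeroed gradient: the plain update is `p - lr * 0`, so `p` does not move, while momentum (stored velocity) or weight decay (adds `weight_decay * p` to the gradient) still moves an inactive parameter.

```python
import torch

def sgd_moves_inactive_param(momentum=0.0, weight_decay=0.0):
    """Return True if an SGD step with an all-zero gradient changes p."""
    torch.manual_seed(0)
    p = torch.nn.Parameter(torch.randn(3))  # hypothetical "decoder" parameter
    opt = torch.optim.SGD([p], lr=0.1, momentum=momentum, weight_decay=weight_decay)

    # One active step with a non-zero gradient.
    p.sum().backward()
    opt.step()

    # Inactive step: keep an all-zero gradient instead of None.
    opt.zero_grad(set_to_none=False)
    before = p.detach().clone()
    opt.step()
    return not torch.equal(before, p.detach())

print(sgd_moves_inactive_param())                   # plain SGD
print(sgd_moves_inactive_param(momentum=0.9))       # SGD with momentum
print(sgd_moves_inactive_param(weight_decay=1e-2))  # SGD with weight decay
```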
Currently I don’t know another clean approach.

Okay. Thank you for your comment 🙂

That’s a great question. Unfortunately, I don’t have a good answer to that, but I would be very interested to hear whether @ptrblck's suggestion of using SGD addresses the problem in your case. (And if you find another good solution in the future, please share!)

I created a small gist showing this effect.
There is some boilerplate code, but it shows the effect of Adam and SGD on the parameter updates.
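The gist itself is not reproduced here. As a rough, hypothetical stand-in, the sketch below uses one optimizer over all parameters, runs a decoder-1 step followed by a decoder-2 step, and lists the parameters changed by the second step. `zero_grad(set_to_none=False)` is an assumption meant to mimic older PyTorch releases, where `zero_grad()` filled gradients with zeros instead of clearing them.

```python
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    """Hypothetical toy model: shared encoder, two decoders."""
    def __init__(self):
        super().__init__()
        self.enc = nn.Conv1d(1, 1, kernel_size=2)
        self.dec1 = nn.Conv1d(1, 1, kernel_size=2)
        self.dec2 = nn.Conv1d(1, 1, kernel_size=2)

    def forward(self, x, decoder_idx):
        dec = self.dec1 if decoder_idx == 1 else self.dec2
        return dec(self.enc(x))

def changed_params(use_adam):
    """Names of parameters changed by the decoder-2 step that follows a decoder-1 step."""
    torch.manual_seed(0)
    model = ToyModel()
    params = model.parameters()
    opt = torch.optim.Adam(params, lr=1e-2) if use_adam else torch.optim.SGD(params, lr=1e-2)
    x = torch.randn(4, 1, 16)
    changed = []
    for decoder_idx in (1, 2):
        opt.zero_grad(set_to_none=False)  # zero existing grads, as older releases did
        before = {n: p.detach().clone() for n, p in model.named_parameters()}
        loss = model(x, decoder_idx).pow(2).mean()  # placeholder loss
        loss.backward()
        opt.step()
        changed = [n for n, p in model.named_parameters()
                   if not torch.equal(before[n], p.detach())]
    return changed

print("SGD :", changed_params(use_adam=False))
print("Adam:", changed_params(use_adam=True))
```

With SGD, only the encoder and decoder 2 should change in the second step; with Adam, decoder 1 moves as well.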

Thanks a lot, that’s great!

If I read it correctly, though, there is no update of inactive parameters: when I run the code in PyTorch 0.4, I get the following for the decoder 1 branch:

Dec2 grad: None
Diff in enc.weight
Diff in enc.bias
Diff in dec1.weight
Diff in dec1.bias

And vice versa, I get the following for the decoder 2 branch:

Diff in enc.weight
Diff in enc.bias
Diff in dec2.weight
Diff in dec2.bias

So using something like
itertools.chain(model.enc.parameters(), model.dec1.parameters())
instead of the if/else with model.parameters() seems like overkill, I guess. I am not sure whether this could have been a problem in older PyTorch versions, though.
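For reference, the `itertools.chain` variant could look roughly like this, with one Adam optimizer per encoder+decoder branch. The attribute names `enc`, `dec1`, `dec2` follow the printed output; the model class and the placeholder loss are hypothetical. One trade-off of this layout: each branch optimizer keeps its own, separate Adam state for the shared encoder.

```python
import itertools
import torch
import torch.nn as nn

class BranchModel(nn.Module):
    """Hypothetical toy model: shared encoder `enc`, decoders `dec1` and `dec2`."""
    def __init__(self):
        super().__init__()
        self.enc = nn.Conv1d(1, 1, kernel_size=2)
        self.dec1 = nn.Conv1d(1, 1, kernel_size=2)
        self.dec2 = nn.Conv1d(1, 1, kernel_size=2)

    def forward(self, x, decoder_idx):
        dec = self.dec1 if decoder_idx == 1 else self.dec2
        return dec(self.enc(x))

model = BranchModel()
# One optimizer per branch; itertools.chain joins the two parameter iterators.
branch_opts = {
    1: torch.optim.Adam(itertools.chain(model.enc.parameters(), model.dec1.parameters()), lr=1e-3),
    2: torch.optim.Adam(itertools.chain(model.enc.parameters(), model.dec2.parameters()), lr=1e-3),
}

def train_step(x, decoder_idx):
    opt = branch_opts[decoder_idx]
    opt.zero_grad()
    loss = model(x, decoder_idx).pow(2).mean()  # placeholder loss
    loss.backward()
    opt.step()  # updates enc and the active decoder only
    return loss.item()
```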

Well, this is the output if SGD is used.
If you set use_adam=True (top of the script), the second update changes all parameters.

Actually, I was trying to replicate this paper: https://arxiv.org/pdf/1805.07848.pdf
In it, the authors trained multiple decoders that share one encoder using Adam.
I don’t know how they avoided this problem.

I updated the gist with two separate optimizers using Adam.
It should work this way.
I’m not sure how the original paper was implemented, either.

Sorry, my bad, I didn’t look carefully enough. In the first pass it looked okay, but in the second pass the problem occurred, probably because Adam's running averages are still zero in the first pass.
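The mechanism can be written out with the standard bias-corrected Adam update (no weight decay), for a parameter $\theta$ whose gradient is a zero tensor at step $t$ but was non-zero at some earlier step:

```latex
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t = \beta_1 m_{t-1} && (g_t = 0),\\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 = \beta_2 v_{t-1},\\
\theta_t &= \theta_{t-1} - \alpha\, \frac{m_t / (1-\beta_1^t)}{\sqrt{v_t / (1-\beta_2^t)} + \epsilon}.
\end{aligned}
```

Before a decoder has ever been active, Adam holds no running averages for it, so it stays put. Once it has been active, $m_{t-1} \neq 0$, and every later step moves it even though $g_t = 0$.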

Using two separate optimizers, as shown in the updated gist, sounds like the best idea in this case.