Hi, I am designing an autoencoder with multiple decoders. And I want to update a selected decoder during the training. For example, below is a toy example of the network that has two decoders and a shared encoder.
class myModel(nn.Module):
    def __init__(self):
        super(myModel, self).__init__()
        self.ENC = nn.ModuleList()
        self.ENC.append(nn.Conv1d(1, 1, kernel_size=2))
        self.DEC1 = nn.ModuleList()
        self.DEC1.append(nn.Conv1d(1, 1, kernel_size=2))
        self.DEC2 = nn.ModuleList()
        self.DEC2.append(nn.Conv1d(1, 1, kernel_size=2))

    def forward(self, x, decoder_idx):
        for f in self.ENC:
            x = f(x)
        if decoder_idx == 1:
            for f in self.DEC1:
                x = f(x)
        elif decoder_idx == 2:
            for f in self.DEC2:
                x = f(x)
        return x
model = myModel()

# decoder 1 branch (optimizer, criterion, x, target defined elsewhere)
out = model(x, 1)
loss = criterion(out, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# decoder 2 branch
out = model(x, 2)
loss = criterion(out, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
When I test with out = model(x, 1), it only updates the parameters of ENC and DEC1, as I expected. However, when I then run out = model(x, 2), it updates ENC and both DEC1 and DEC2. How can I update them selectively in a continuous training process?
This issue might be related to your optimizer.
If you use an optimizer with a momentum or some averaged statistics and pass all parameters to it, your inactive parameters might still be updated.
Which optimizer are you currently using?
I think the safest way would be to create two (or even three) optimizers and pass the corresponding parameters to each of them.
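For the toy model above, that could look something like this (a sketch with standalone modules and a dummy `.sum()` loss; the module and optimizer names are my own):

```python
import torch
import torch.nn as nn

enc  = nn.Conv1d(1, 1, kernel_size=2)
dec1 = nn.Conv1d(1, 1, kernel_size=2)
dec2 = nn.Conv1d(1, 1, kernel_size=2)

# Three optimizers: each Adam only ever sees (and keeps running
# statistics for) its own sub-network's parameters.
opt_enc  = torch.optim.Adam(enc.parameters(),  lr=1e-2)
opt_dec1 = torch.optim.Adam(dec1.parameters(), lr=1e-2)
opt_dec2 = torch.optim.Adam(dec2.parameters(), lr=1e-2)

x = torch.randn(1, 1, 8)
w2_before = dec2.weight.clone()

# Train the decoder-1 branch twice; only zero/step the optimizers
# of the modules that were actually used.
for _ in range(2):
    opt_enc.zero_grad()
    opt_dec1.zero_grad()
    dec1(enc(x)).sum().backward()
    opt_enc.step()
    opt_dec1.step()

dec2_untouched = torch.equal(dec2.weight, w2_before)
print(dec2_untouched)  # True: dec2 is never modified
```

Since opt_dec2 is never stepped, dec2 cannot drift, no matter how many updates the other branch receives.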
Thank you for your reply. I am currently using Adam. Yes, probably this can affect.
Do you think using multiple optimizers is the only option? It means I have to use 16 optimizers for 16 decoders.
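You wouldn't have to write 16 optimizers out by hand, though – they can be built in a loop, e.g. one Adam for the shared encoder plus one per decoder. A sketch (the `MultiDecoder` class and `train_step` helper are made-up names for illustration):

```python
import torch
import torch.nn as nn

class MultiDecoder(nn.Module):
    """Toy shared encoder with N decoder heads."""
    def __init__(self, n):
        super().__init__()
        self.enc = nn.Conv1d(1, 1, kernel_size=2)
        self.decs = nn.ModuleList(
            nn.Conv1d(1, 1, kernel_size=2) for _ in range(n))

    def forward(self, x, idx):
        return self.decs[idx](self.enc(x))

model = MultiDecoder(16)

# One Adam for the encoder, one per decoder, built in a loop.
enc_opt = torch.optim.Adam(model.enc.parameters(), lr=1e-3)
dec_opts = [torch.optim.Adam(d.parameters(), lr=1e-3) for d in model.decs]

def train_step(x, idx, target):
    # Only zero/step the optimizers whose modules were used.
    out = model(x, idx)
    loss = nn.functional.mse_loss(out, target)
    enc_opt.zero_grad()
    dec_opts[idx].zero_grad()
    loss.backward()
    enc_opt.step()
    dec_opts[idx].step()
    return loss.item()

x = torch.randn(1, 1, 8)           # two k=2 convs: output length 6
target = torch.randn(1, 1, 6)
w5_before = model.decs[5].weight.clone()
train_step(x, 0, target)
other_untouched = torch.equal(model.decs[5].weight, w5_before)
```

The per-decoder optimizer state adds some memory, but the bookkeeping stays a one-liner.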
That’s a great question. Unfortunately, I don’t have a good answer to that, but I would be very interested to hear whether @ptrblck’s suggestion of using SGD addresses the problem in your case. (And in case you find another good solution in the future, please share!)
If I see it correctly, though, there’s no update of the inactive parameters, i.e., when I run the code in PyTorch 0.4, I get the following for the decoder 1 branch:
Dec2 grad: None
Diff in enc.weight
Diff in enc.bias
Diff in dec1.weight
Diff in dec1.bias
And vice versa, I get the following for the decoder 2 branch:
Diff in enc.weight
Diff in enc.bias
Diff in dec2.weight
Diff in dec2.bias
So, using something like itertools.chain(model.enc.parameters(), model.dec1.parameters()) instead of the if/else and model.parameters() seems like overkill, I guess. Not sure whether this could have been a problem in older PyTorch versions, though.
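For reference, a minimal version of that check (my own reconstruction, not the exact script behind the output above):

```python
import itertools
import torch
import torch.nn as nn

enc  = nn.Conv1d(1, 1, kernel_size=2)
dec1 = nn.Conv1d(1, 1, kernel_size=2)
dec2 = nn.Conv1d(1, 1, kernel_size=2)

x = torch.randn(1, 1, 8)
dec1(enc(x)).sum().backward()   # decoder 1 branch only

print('Dec2 grad:', dec2.weight.grad)            # None - dec2 never ran
print('Dec1 has grad:', dec1.weight.grad is not None)

# The explicit alternative mentioned above: hand only the active
# branch's parameters to the optimizer via itertools.chain.
opt = torch.optim.SGD(
    itertools.chain(enc.parameters(), dec1.parameters()), lr=0.1)
opt.step()
```

With plain SGD and no momentum, restricting the parameter list like this is indeed redundant, since parameters with grad None are skipped anyway.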
Actually, I was trying to replicate this research: https://arxiv.org/pdf/1805.07848.pdf
In this paper, the authors train multiple decoders that share an encoder, using Adam.
I don’t know how they avoided this problem.
Sorry, my bad, I didn’t look closely enough. The first pass looked okay, but the problem occurs in the second pass – probably because Adam’s running statistics are still zero during the first pass.
Using two separate optimizers, as shown in the updated gist, sounds like the best idea in this case.
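To make the mechanism concrete: after the first backward, zero_grad() resets the active decoder’s gradients to *zero* tensors rather than None, and Adam’s momentum buffers for it are non-zero, so a later step moves parameters that received a zero gradient. A minimal reproduction (sketch; note it passes `set_to_none=False` to force zero-filled gradients, since recent PyTorch versions default to setting gradients to None, which avoids this effect):

```python
import torch
import torch.nn as nn

enc  = nn.Conv1d(1, 1, kernel_size=2)
dec1 = nn.Conv1d(1, 1, kernel_size=2)
dec2 = nn.Conv1d(1, 1, kernel_size=2)

# A single Adam over all parameters, as in the original code.
opt = torch.optim.Adam(
    list(enc.parameters()) + list(dec1.parameters())
    + list(dec2.parameters()), lr=1e-2)

x = torch.randn(1, 1, 8)

# Pass 1: decoder 1 branch. dec2's grads are None, so Adam skips it.
opt.zero_grad(set_to_none=False)
dec1(enc(x)).sum().backward()
opt.step()

# Pass 2: decoder 2 branch. zero_grad(set_to_none=False) leaves dec1
# with zero (not None) grads, and Adam's momentum for dec1 is
# non-zero from pass 1, so dec1 still moves.
w1_before = dec1.weight.clone()
opt.zero_grad(set_to_none=False)
dec2(enc(x)).sum().backward()
opt.step()

dec1_moved = not torch.equal(dec1.weight, w1_before)
print(dec1_moved)  # True: the "inactive" decoder was still updated
```

Separate optimizers sidestep this entirely, because the momentum state for each branch lives in an optimizer that is only ever stepped when that branch is active.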