Hi. I am trying to implement a network with a single shared encoder and two decoders. Both decoders are trained simultaneously on the same data. I have figured out that if I sum the losses of both decoders and perform backpropagation, the gradients will be accumulated on the encoder as well. However, I am confused about how validation should be performed for such a network. Any help will be greatly appreciated. Thank you!
I’m not sure why this is a problem. I’ve never tried it, but that’s how I would do it. Below is a code snippet of an RNN-based encoder-decoder model:
```python
class RnnAttentionSeq2Seq(nn.Module):

    def __init__(self, params, criterion):
        super().__init__()
        self.params = params
        self.criterion = criterion
        self.encoder = Encoder(params)
        self.decoder = Decoder(params, self.criterion)

    def forward(self, X, Y):
        # Push through encoder
        encoder_outputs, encoder_hidden = self.encoder(X)
        # Push through decoder
        loss = self.decoder(Y, encoder_hidden, encoder_outputs)
        return loss
```
If I had 2 decoders, I would simply change it to:
```python
class RnnAttentionSeq2Seq(nn.Module):

    def __init__(self, params, criterion):
        super().__init__()
        self.params = params
        self.criterion = criterion
        self.encoder = Encoder(params)
        self.decoder1 = Decoder(params, self.criterion)
        self.decoder2 = Decoder(params, self.criterion)

    def forward(self, X, Y1, Y2):
        # Push through encoder
        encoder_outputs, encoder_hidden = self.encoder(X)
        # Push through decoders
        loss1 = self.decoder1(Y1, encoder_hidden, encoder_outputs)
        loss2 = self.decoder2(Y2, encoder_hidden, encoder_outputs)
        return loss1 + loss2
```
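To make the gradient flow concrete, here is a minimal training-step sketch. The linear layers are just stand-ins for the real `Encoder`/`Decoder` modules (their shapes and the loss are assumptions for illustration); the point is that a single `backward()` on the summed loss deposits gradients from both decoders onto the shared encoder.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for Encoder/Decoder, just to show the gradient flow
encoder = nn.Linear(8, 4)
decoder1 = nn.Linear(4, 3)
decoder2 = nn.Linear(4, 3)

# One optimizer over all parameters, shared encoder included
optimizer = torch.optim.Adam(
    list(encoder.parameters())
    + list(decoder1.parameters())
    + list(decoder2.parameters()),
    lr=1e-3,
)
criterion = nn.MSELoss()

X = torch.randn(16, 8)
Y1 = torch.randn(16, 3)
Y2 = torch.randn(16, 3)

optimizer.zero_grad()
h = encoder(X)                        # shared representation, computed once
loss1 = criterion(decoder1(h), Y1)
loss2 = criterion(decoder2(h), Y2)
loss = loss1 + loss2                  # summed loss, as in the snippet above
loss.backward()                       # encoder gets gradients from BOTH decoders
optimizer.step()
```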
The above is the right way to go, I think. You might want something like `t * loss1 + (1 - t) * loss2` to balance the two losses (where `t` is a fraction between 0 and 1).
For validation, you can do the same forward pass on held-out data, just without backpropagation: track each decoder’s validation loss separately so you can monitor the two tasks independently. After training, you’d also evaluate with the metric appropriate to each task. For example, if loss 1 is for translation and loss 2 is for classification, then given an input you’d compute BLEU for the translation task and accuracy for the classification task. I’d imagine the tasks should be somewhat related (the encoder parameters are shared, so each task should “gain” from the other), but I’d have to think more about that. Let me know if this is helpful to you …
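A minimal validation-loop sketch of what I mean, again with linear stand-ins for the real modules (the `val_loader` shape and `MSELoss` are assumptions): run the shared encoder once per batch under `torch.no_grad()` and accumulate each decoder’s loss separately.

```python
import torch
import torch.nn as nn

def validate(encoder, decoder1, decoder2, criterion, val_loader):
    """Return the mean validation loss of each decoder separately."""
    encoder.eval(); decoder1.eval(); decoder2.eval()
    total1 = total2 = n = 0.0
    with torch.no_grad():                     # no gradients needed for validation
        for X, Y1, Y2 in val_loader:
            h = encoder(X)                    # shared encoder, run once per batch
            total1 += criterion(decoder1(h), Y1).item() * X.size(0)
            total2 += criterion(decoder2(h), Y2).item() * X.size(0)
            n += X.size(0)
    encoder.train(); decoder1.train(); decoder2.train()
    return total1 / n, total2 / n             # per-task validation losses

# Toy usage with hypothetical shapes and a one-batch "loader":
enc, dec1, dec2 = nn.Linear(8, 4), nn.Linear(4, 3), nn.Linear(4, 3)
loader = [(torch.randn(16, 8), torch.randn(16, 3), torch.randn(16, 3))]
val1, val2 = validate(enc, dec1, dec2, nn.MSELoss(), loader)
```

Reporting the two numbers separately (rather than only their sum) lets you see whether one task is overfitting or dominating while the other still improves.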