Transformers with scheduled sampling implementation

I am trying to implement a seq2seq model using transformers with scheduled sampling.
The method I am trying to implement is the one described in Mihaylova (2019).

My question is about how to implement the two decoders.

I am unsure which of these two approaches is the correct one (perhaps neither is):

  • Forward the sample twice through the decoder, first with the gold targets and second with a mix of gold and predictions.
  • Have two separate decoders, one for the first pass and one for the second, that share weights and biases.

Here is some pseudocode to illustrate each case.

First approach:

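import torch
import torch.nn as nn

class TransformersScheduled(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()  # must run before submodules are assigned
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, trg):
        memory = self.encoder(src)
        # first pass: decoder conditioned on the gold targets
        prediction = self.decoder(memory, trg)
        new_trg = self.mix_gold_pred(trg, prediction, 0.5)
        # second pass: same decoder, conditioned on the gold/prediction mix
        prediction = self.decoder(memory, new_trg)
        return prediction

    def mix_gold_pred(self, trg, pred, factor):
        # mix gold standard labels with predicted labels
        # assumption: trg is (batch, len) token ids, pred is (batch, len, vocab) logits,
        # and `factor` is the probability of keeping the gold token at each position
        # (shifting predictions to line up with the decoder inputs is omitted)
        pred_ids = pred.argmax(dim=-1)
        keep_gold = torch.rand(trg.shape, device=trg.device) < factor
        new_trg = torch.where(keep_gold, trg, pred_ids)
        return new_trg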

## loss is calculated on the second prediction and backpropagation is applied
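
To make the training step concrete, here is a minimal sketch of one update for the first approach. `ToyEncoder`, `ToyDecoder`, the sizes, and the 0.5 mixing rule are hypothetical stand-ins chosen only so the example runs; they are not the architecture from the paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
VOCAB, DIM = 11, 8  # hypothetical toy sizes


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for a Transformer encoder."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB, DIM)

    def forward(self, src):
        # (batch, src_len) token ids -> (batch, DIM) summary vector
        return self.embed(src).mean(dim=1)


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for a Transformer decoder."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB, DIM)
        self.out = nn.Linear(DIM, VOCAB)

    def forward(self, memory, trg):
        # (batch, DIM), (batch, len) -> (batch, len, VOCAB) logits
        return self.out(self.embed(trg) + memory.unsqueeze(1))


encoder, decoder = ToyEncoder(), ToyDecoder()
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.SGD(params, lr=0.1)

src = torch.randint(0, VOCAB, (4, 6))
trg_in = torch.randint(0, VOCAB, (4, 5))   # decoder inputs (gold)
trg_out = torch.randint(0, VOCAB, (4, 5))  # tokens the decoder should predict

memory = encoder(src)
first_logits = decoder(memory, trg_in)      # first pass, gold inputs
keep_gold = torch.rand(trg_in.shape) < 0.5  # assumed mixing rule
new_trg = torch.where(keep_gold, trg_in, first_logits.argmax(dim=-1))
second_logits = decoder(memory, new_trg)    # second pass, mixed inputs

# Loss on the second prediction only, then one optimizer step.
loss = F.cross_entropy(second_logits.reshape(-1, VOCAB), trg_out.reshape(-1))
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

The same decoder instance is called twice, and only `second_logits` enters the loss.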

Second approach:

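import torch
import torch.nn as nn

class TransformersScheduled(nn.Module):
    def __init__(self, encoder, decoder1, decoder2):
        super().__init__()  # must run before submodules are assigned
        self.encoder = encoder
        self.decoder1 = decoder1
        self.decoder2 = decoder2

    def forward(self, src, trg):
        memory = self.encoder(src)
        # copy the weights from decoder2 to decoder1
        # (nn.Module has no generic `.weights` attribute; state dicts copy all parameters)
        self.decoder1.load_state_dict(self.decoder2.state_dict())
        prediction = self.decoder1(memory, trg)
        new_trg = self.mix_gold_pred(trg, prediction, 0.5)
        prediction = self.decoder2(memory, new_trg)
        return prediction

    def mix_gold_pred(self, trg, pred, factor):
        # mix gold standard labels with predicted labels
        # assumption: trg is (batch, len) token ids, pred is (batch, len, vocab) logits,
        # and `factor` is the probability of keeping the gold token at each position
        # (shifting predictions to line up with the decoder inputs is omitted)
        pred_ids = pred.argmax(dim=-1)
        keep_gold = torch.rand(trg.shape, device=trg.device) < factor
        new_trg = torch.where(keep_gold, trg, pred_ids)
        return new_trg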

## loss is calculated on the second prediction and backpropagation is applied
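
Whether the two decoders truly share weights or only hold copies of them affects how updates propagate. A minimal sketch of the difference, assuming PyTorch and using `nn.Linear` as a hypothetical stand-in for a full decoder:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# nn.Linear is only a hypothetical stand-in for a full decoder.
decoder2 = nn.Linear(4, 3)

# Option 1: copy the values; decoder1 keeps its own, separate parameters.
decoder1 = nn.Linear(4, 3)
decoder1.load_state_dict(decoder2.state_dict())
copied_equal = torch.equal(decoder1.weight, decoder2.weight)
copied_same_tensor = decoder1.weight is decoder2.weight

# Option 2: reuse the module; both names refer to the same parameters.
tied_decoder = decoder2
tied_same_tensor = tied_decoder.weight is decoder2.weight

# After an update to decoder2, only the tied version follows automatically.
with torch.no_grad():
    decoder2.weight.add_(1.0)
copy_still_equal = torch.equal(decoder1.weight, decoder2.weight)
tied_still_equal = torch.equal(tied_decoder.weight, decoder2.weight)

print(copied_equal, copied_same_tensor, tied_same_tensor)
print(copy_still_equal, tied_still_equal)
```

With copies, the values have to be re-synchronised after every optimizer step. Reusing one module instance avoids that bookkeeping.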

In the first case, would there be a “double” backpropagation through the decoder, since it is used twice during the forward pass?
In the second case, would there be backpropagation through the first decoder?
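
One way to probe these questions empirically is to compare gradients with and without tracking the first pass. The sketch below assumes the mix uses hard `argmax` token ids, and the tiny `nn.Sequential` decoder is a hypothetical stand-in. A differentiable mix (for example, averaging embeddings) could behave differently.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
VOCAB, DIM = 7, 4  # hypothetical toy sizes

# Hypothetical stand-in for the decoder: token ids -> logits per position.
decoder = nn.Sequential(nn.Embedding(VOCAB, DIM), nn.Linear(DIM, VOCAB))

trg_in = torch.randint(0, VOCAB, (2, 5))
trg_out = torch.randint(0, VOCAB, (2, 5))


def grads_of(loss):
    """Return gradients of `loss` w.r.t. the decoder parameters."""
    return torch.autograd.grad(loss, list(decoder.parameters()))


# Two passes through the same decoder, first pass tracked by autograd.
first_logits = decoder(trg_in)
pred_ids = first_logits.argmax(dim=-1)  # integer tensor, carries no gradient
loss_tracked = F.cross_entropy(
    decoder(pred_ids).reshape(-1, VOCAB), trg_out.reshape(-1)
)
grads_tracked = grads_of(loss_tracked)

# Same computation, but the first pass is run without building a graph.
with torch.no_grad():
    pred_ids_ng = decoder(trg_in).argmax(dim=-1)
loss_untracked = F.cross_entropy(
    decoder(pred_ids_ng).reshape(-1, VOCAB), trg_out.reshape(-1)
)
grads_untracked = grads_of(loss_untracked)

same_grads = all(
    torch.allclose(a, b) for a, b in zip(grads_tracked, grads_untracked)
)
print(pred_ids.requires_grad, same_grads)
```

If `same_grads` is true, the first pass contributes nothing to the gradient under this hard mixing rule, so only the second pass is backpropagated through.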

Thanks in advance for your time and help!