Partial Backprop and Weight Update

I am trying to implement a model based on the architecture in Scheduled Sampling for Transformers, and I’m getting lost in the details. Figure 1 in that paper is an excellent illustration of what they’re doing, in the context of a Transformer model:

  • They use a single encoder that takes the source inputs and produces the encoder outputs.

  • The first of two decoders attends to those encoder outputs and conditions on the ground-truth (teacher-forcing) targets, producing token estimates from that teacher-forced pass.

  • The second of two decoders attends to the same encoder outputs, but now conditions on a token-by-token probabilistic mixture of the ground-truth tokens and the estimates produced by the first decoder.

(This is the Transformer variant of Scheduled Sampling, of course.)
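
To make the question concrete, here is roughly the forward structure I am picturing for those three pieces. All of the names and hyperparameters are my own placeholders rather than the paper's code, and I've left out the causal masks and target shifting to keep the sketch short:

```python
import torch
import torch.nn as nn

class ScheduledSamplingTransformer(nn.Module):
    """Sketch of the two-pass architecture; names and sizes are placeholders."""

    def __init__(self, vocab_size=1000, d_model=512, nhead=8, num_layers=6):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead, batch_first=True), num_layers)
        # Two decoder passes: the first conditions on the gold targets, the
        # second on the gold/prediction mixture. They are supposed to share
        # weights, which is exactly the part I am asking about below.
        self.decoder_teacher = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, nhead, batch_first=True), num_layers)
        self.decoder_mix = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, nhead, batch_first=True), num_layers)
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, src_ids, tgt_ids, mix_prob=0.5):
        # Shared encoder output, used by both decoder passes.
        memory = self.encoder(self.embed(src_ids))

        # Pass 1: teacher-forced decoding; argmax gives the model's own estimates.
        logits_teacher = self.proj(self.decoder_teacher(self.embed(tgt_ids), memory))
        pred_ids = logits_teacher.argmax(dim=-1)

        # Token-by-token mixture of gold targets and pass-1 predictions.
        use_pred = torch.rand(tgt_ids.shape, device=tgt_ids.device) < mix_prob
        mixed_ids = torch.where(use_pred, pred_ids, tgt_ids)

        # Pass 2: decode conditioned on the mixture; these are the final predictions.
        return self.proj(self.decoder_mix(self.embed(mixed_ids), memory))
```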

Here is the critical trick: the two decoders share identical weights, but, in the words of the paper, they “only backpropagate through the decoder which makes the final predictions, based on mix between the [ground truth] target and the model predictions.”
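
For context, my training step so far just computes the loss on the second pass's output (continuing the placeholder names from the sketch above). What I can't tell is whether anything more is needed to guarantee that only the mix-conditioned decoder gets trained:

```python
import torch
import torch.nn.functional as F

model = ScheduledSamplingTransformer()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

def training_step(src_ids, tgt_ids):
    logits_final = model(src_ids, tgt_ids, mix_prob=0.5)
    # The loss is computed only on the final (mix-conditioned) predictions...
    loss = F.cross_entropy(logits_final.transpose(1, 2), tgt_ids)
    optimizer.zero_grad()
    loss.backward()   # ...but does this alone keep decoder_teacher from being updated?
    optimizer.step()
    return loss.item()
```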

I’m a little confused as to how to meet both of those criteria with PyTorch.

  1. I think I can do this by making two distinct copies of the decoder, setting requires_grad = False on the first copy's parameters, and copying the weights over from the second at each step (sketched below, after this list). Is this correct?

  2. But I feel like I’m missing some clever way to do this all with only one copy of the decoder resident in memory.
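
Concretely, what I have in mind for option 1 is something like the following (again with my placeholder names). It feels clumsy, which is part of why I suspect the cleaner approach in question 2 exists:

```python
# Option 1 sketch: two decoder instances, freeze the first, and copy the
# trained weights into it once per step (placeholder names from above).
for p in model.decoder_teacher.parameters():
    p.requires_grad_(False)

def sync_decoders():
    # Copy the weights of the trained (mix) decoder into the frozen one.
    model.decoder_teacher.load_state_dict(model.decoder_mix.state_dict())

# ...called after every optimizer step:
# optimizer.step()
# sync_decoders()
```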