I am trying to implement a model based on the architecture in Scheduled Sampling for Transformers, and I’m getting lost in the details. Figure 1 in that paper is an excellent illustration of what they’re doing, in the context of a Transformer model:

They make a single encoder taking inputs and providing an encoder output.

They make the first of two decoders, which takes the encoder outputs, conditions on the ground truth (teacher forcing) data, and produces token estimates based on that teacher forcing data.

They make the second of two decoders, which takes the same encoder outputs but now conditions on a token-by-token probabilistic mixture of the ground truth data and the previously generated estimates.
(This is a two-pass variant of Scheduled Sampling, of course.)
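To make the mixing step concrete, here is a minimal sketch of how I picture it. It assumes the first-pass estimates have already been reduced to hard token ids (e.g. by argmax); the name mix_targets and the parameter teacher_forcing_prob are my own, not from the paper.

```python
import torch

def mix_targets(gold_ids, pred_ids, teacher_forcing_prob):
    """Per position, keep the gold token with probability teacher_forcing_prob,
    otherwise use the first-pass prediction. (Hypothetical helper.)"""
    # One independent coin flip per token position.
    use_gold = torch.rand(gold_ids.shape, device=gold_ids.device) < teacher_forcing_prob
    return torch.where(use_gold, gold_ids, pred_ids)

torch.manual_seed(0)
gold = torch.randint(0, 100, (2, 5))   # (batch, seq_len) ground truth ids
pred = torch.randint(0, 100, (2, 5))   # (batch, seq_len) first-pass estimates
mixed = mix_targets(gold, pred, teacher_forcing_prob=0.75)
```

With teacher_forcing_prob = 1.0 this reduces to plain teacher forcing, and with 0.0 the second decoder sees only the model's own estimates.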
Here is the critical trick: The two decoders share identical weights, but in the words of the paper, they: “Only backpropagate through the decoder which makes the final predictions, based on mix between the [ground truth] target and the model predictions.”
I’m a little confused as to how to meet both of those criteria with PyTorch.

I think I can do this by making two distinct decoder objects, setting requires_grad = False for the first decoder, and copying the weights from the second at each step. Is this correct?
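For concreteness, the two-copy idea would look roughly like the sketch below. The nn.Linear layer is only a hypothetical stand-in for a real Transformer decoder, and the loss is a placeholder; the point is just the freeze-and-resync pattern.

```python
import copy
import torch
import torch.nn as nn

# Hypothetical stand-in for a real Transformer decoder.
decoder = nn.Linear(8, 8)            # trainable copy, used for the second pass
frozen = copy.deepcopy(decoder)      # separate copy, used for the first pass
frozen.requires_grad_(False)         # its parameters never receive gradients

optimizer = torch.optim.SGD(decoder.parameters(), lr=0.1)

for step in range(3):
    # Re-sync the frozen copy with the current trainable weights.
    frozen.load_state_dict(decoder.state_dict())

    x = torch.randn(4, 8)            # placeholder for decoder inputs
    first_pass = frozen(x)           # first pass through the frozen copy
    loss = decoder(first_pass).pow(2).mean()  # placeholder loss on the second pass

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

One caveat I can see: requires_grad = False only freezes the first decoder's own parameters. If its input came from an encoder that requires gradients, the first pass would still be part of the graph unless its output is detached or converted to token ids.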
But I feel like I’m missing some clever way to do this all with only one copy of the decoder resident in memory.
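One candidate, sketched under my own assumptions rather than confirmed against the paper: keep a single decoder and run the first pass inside torch.no_grad(), so both passes use the same weights by construction and only the second pass is recorded for backpropagation. The toy embedding-plus-linear "decoder" below is hypothetical, and the sketch ignores causal masking, the encoder, and the right-shift needed to turn predictions into next-step inputs.

```python
import torch
import torch.nn as nn

vocab, dim = 20, 16
embed = nn.Embedding(vocab, dim)     # hypothetical toy decoder: embedding + projection
decoder = nn.Linear(dim, vocab)
params = list(embed.parameters()) + list(decoder.parameters())
optimizer = torch.optim.SGD(params, lr=0.1)

gold_in = torch.randint(0, vocab, (2, 5))    # teacher forcing inputs
gold_out = torch.randint(0, vocab, (2, 5))   # targets for the loss

# Pass 1: same weights, but no autograd graph is recorded.
with torch.no_grad():
    pred = decoder(embed(gold_in)).argmax(dim=-1)

# Token-by-token mix of ground truth and first-pass estimates (probability 0.5 here).
use_gold = torch.rand(gold_in.shape) < 0.5
mixed_in = torch.where(use_gold, gold_in, pred)

# Pass 2: the only pass that contributes gradients.
logits = decoder(embed(mixed_in))
loss = nn.functional.cross_entropy(logits.reshape(-1, vocab), gold_out.reshape(-1))

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

If this is right, there is only one set of decoder weights in memory and nothing to copy between steps, at the cost of running the decoder forward twice per batch.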