Efficiently Generate Sequence from Transformer Decoder

I’m currently using a Transformer decoder as an autoregressive model to generate a sequence (so each element of the sequence depends on all previously generated elements).

To generate a sequence, I have to produce the elements one at a time. On every forward pass of the Transformer decoder, the representations of all the previous elements are recomputed, but I actually only need the output for the very last position to append to the sequence (the previous elements are unchanged).
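
For reference, here is roughly what my generation loop looks like right now (the sizes, the greedy decoding, and the random `memory` standing in for the encoder output are just placeholders):

```python
import torch
import torch.nn as nn

# Placeholder sizes just for illustration.
d_model, nhead, num_layers, vocab_size, max_len = 64, 4, 2, 100, 20

embed = nn.Embedding(vocab_size, d_model)
layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True)
decoder = nn.TransformerDecoder(layer, num_layers)
to_logits = nn.Linear(d_model, vocab_size)

memory = torch.randn(1, 10, d_model)       # stand-in for the encoder output
seq = torch.zeros(1, 1, dtype=torch.long)  # start token

with torch.no_grad():
    for _ in range(max_len):
        tgt = embed(seq)
        L = tgt.size(1)
        causal_mask = torch.triu(torch.full((L, L), float("-inf")), diagonal=1)
        # The entire prefix goes through the decoder again on every step,
        # even though only the output at the last position is used below.
        out = decoder(tgt, memory, tgt_mask=causal_mask)
        next_token = to_logits(out[:, -1]).argmax(dim=-1, keepdim=True)
        seq = torch.cat([seq, next_token], dim=1)
```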

Is there any way to do this more efficiently in PyTorch with the MultiheadAttention/TransformerDecoder modules, e.g. by caching and reusing the computations for the previous positions? (Let me know if my question is unclear.)
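
To make it more concrete, this is the kind of per-step update I have in mind, sketched for a single self-attention layer only (caching the prefix representations is my own bookkeeping here, not an existing feature of these modules); what I don't see is how to get the stock TransformerDecoder to work this way across all of its layers:

```python
import torch
import torch.nn as nn

d_model, nhead = 64, 4
self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)

prefix = torch.randn(1, 9, d_model)  # cached layer inputs for the positions generated so far
newest = torch.randn(1, 1, d_model)  # layer input for the newest position only

# Keys/values span the whole sequence, but only the newest position is used
# as the query, so nothing is recomputed for the earlier positions and no
# causal mask is needed.
kv = torch.cat([prefix, newest], dim=1)
out, _ = self_attn(query=newest, key=kv, value=kv, need_weights=False)
print(out.shape)  # torch.Size([1, 1, 64])
```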