MultiheadAttention for cross attention with beam search (generation), redundant computation for keys/values?

E.g. think of an encoder-decoder architecture.

When you do generation, you will frequently have different queries for the same keys and values (derived from the encoder outputs). For training, that can be a single MultiheadAttention.forward call, but for generation you would call MultiheadAttention.forward at every decoder step. However, it looks like this applies the same key/value projections over and over at every decoder step, which seems like a lot of redundant computation. Is PyTorch somehow clever enough to optimize this away automatically? Or is MultiheadAttention.forward simply not used for generation?
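To make the concern concrete, here is a minimal sketch of the pattern I mean (toy sizes, random tensors standing in for real encoder/decoder states): the encoder memory never changes during decoding, yet every forward call re-projects it.

```python
import torch

torch.manual_seed(0)
E, H, S = 16, 4, 6  # embed dim, num heads, encoder (source) length
mha = torch.nn.MultiheadAttention(E, H, batch_first=True)

memory = torch.randn(1, S, E)  # encoder output, fixed for the whole decode

outputs = []
tgt = torch.randn(1, 1, E)  # query for the first decoder step (placeholder)
for step in range(5):
    # Every call re-applies the K and V input projections to `memory`,
    # even though `memory` has not changed since the previous step.
    out, _ = mha(tgt, memory, memory)
    outputs.append(out)
    tgt = out  # stand-in for the next decoder step's query

print(len(outputs), outputs[0].shape)
```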

I checked both torch.nn.MultiheadAttention and the Fairseq MultiheadAttention (here), and they both seem to have this API.
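For comparison, here is a hedged sketch of what a cached variant could look like, computing the K/V projections once outside the decode loop and reusing the module's own weights. This assumes the default case where kdim == vdim == embed_dim, so the module stores the stacked q/k/v projection in in_proj_weight; it also uses F.scaled_dot_product_attention (PyTorch >= 2.0) for the attention itself. It is an illustration of the idea, not an official API.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
E, H, S, T = 16, 4, 6, 1  # embed dim, heads, source len, one decode step
mha = torch.nn.MultiheadAttention(E, H, batch_first=True)
mha.eval()

memory = torch.randn(1, S, E)  # encoder output (fixed during decoding)
query = torch.randn(1, T, E)   # current decoder step's query

# Reference: the standard call, which re-projects `memory` internally.
ref, _ = mha(query, memory, memory)

# Cached variant: split the stacked input projection (order q, k, v;
# valid in the default kdim == vdim == embed_dim configuration) and
# project K and V from `memory` ONCE, outside the decode loop.
w_q, w_k, w_v = mha.in_proj_weight.chunk(3)
b_q, b_k, b_v = mha.in_proj_bias.chunk(3)
k = F.linear(memory, w_k, b_k)  # cache these two across decode steps
v = F.linear(memory, w_v, b_v)

def split_heads(x):
    b, n, _ = x.shape
    return x.view(b, n, H, E // H).transpose(1, 2)  # (b, H, n, E/H)

q = F.linear(query, w_q, b_q)  # only the query is projected per step
attn = F.scaled_dot_product_attention(split_heads(q), split_heads(k), split_heads(v))
out = mha.out_proj(attn.transpose(1, 2).reshape(1, T, E))

# The cached path should match the module's own output numerically.
print(torch.allclose(ref, out, atol=1e-5))
```

This is essentially the incremental-state / KV-cache trick that generation loops use for self-attention as well, just restricted here to the encoder-side keys and values.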

This is very confusing to me. Why would you do it so inefficiently? Or am I misunderstanding something?