E.g. think of an encoder-decoder architecture. During generation, you will frequently have different queries for the same keys and values (based on the encoder outputs). For training, that can be a single MultiheadAttention.forward call, but for generation you would call MultiheadAttention.forward once per decoder step. It looks like this will redo the same projections of the keys and values in every decoder step, which seems like a lot of redundant computation. Is PyTorch somehow clever enough to automatically optimize this away? Or is MultiheadAttention.forward simply not used for generation?
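To make the redundancy concrete, here is a minimal sketch of the decoding loop I mean (shapes and names are made up for illustration): the encoder output never changes across steps, yet every call re-applies the key/value projections to it.

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 16, 4
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

memory = torch.randn(1, 10, embed_dim)  # encoder output, fixed during generation
query = torch.randn(1, 1, embed_dim)    # a single decoder step's query

for step in range(3):  # per-step generation
    # keys/values are re-projected from the unchanged `memory` on every call
    out, _ = mha(query, memory, memory, need_weights=False)
```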
I checked both torch.nn.MultiheadAttention and the Fairseq MultiheadAttention (here), and they both seem to have this API.
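What I would have expected (or would hack in myself) is projecting the keys and values once, before the decoding loop. A rough sketch of that idea, relying on the internal in_proj_weight layout of torch.nn.MultiheadAttention (rows stacked as [q; k; v]), so the slicing here depends on implementation details:

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 16, 4
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
memory = torch.randn(1, 10, embed_dim)  # fixed encoder output

# in_proj_weight packs the q/k/v projection matrices as stacked rows [q; k; v]
w, b = mha.in_proj_weight, mha.in_proj_bias
w_k, w_v = w[embed_dim:2 * embed_dim], w[2 * embed_dim:]
b_k, b_v = b[embed_dim:2 * embed_dim], b[2 * embed_dim:]

# Project keys/values once, outside the decoding loop ...
k_cached = memory @ w_k.t() + b_k
v_cached = memory @ w_v.t() + b_v
# ... and reuse k_cached/v_cached at every decoder step instead of
# re-projecting `memory` inside MultiheadAttention.forward.
```

But the public forward API takes the unprojected key/value tensors, so there seems to be no supported way to pass such a cache in.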
This is very confusing to me. Why would you do this so inefficiently? Or do I misunderstand something?