How is add_bias_kv implemented in MultiheadAttention?

Hello,

I’m trying to gain more insight into how PyTorch implements the MultiheadAttention layer. Specifically, what does the add_bias_kv parameter do in the code? I see that self.bias_k and self.bias_v are instantiated in addition to in_proj_bias (which is controlled by the bias param). I found the low-level implementation here: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/attention.cpp#L321 - however, it only uses qkv_bias (which I believe is the in_proj_bias?). Can anyone point me to where I can find more specifics on this?

Thank you!

In case it helps someone in the future, I found the implementation here: https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py#L5305. By add_bias_kv, PyTorch actually means a concatenation along the sequence length dimension. If you have k of shape [B, L, E] (batch, sequence length, embedding dim), you repeat bias_k (or bias_v) along the batch dimension and concatenate it along the sequence length dimension. Specifically, bias_k of shape [1, 1, E] is repeated to [B, 1, E], and then concatenated to k of shape [B, L, E] along dim=1 (the repeat and concatenation dimensions vary depending on whether the batch dimension comes first). After that operation, you end up with a longer sequence: [B, L+1, E]. You can see that the source sequence length is updated after these adjustments here: https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py#L5351.
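
For anyone who wants to see it concretely, here is a minimal sketch of what that concatenation amounts to in a batch-first [B, L, E] layout. The variable names and shapes below are illustrative only, not the exact code in functional.py:

```python
import torch

# Minimal sketch of the add_bias_kv concatenation described above,
# using a batch-first [B, L, E] layout. Names and shapes are
# illustrative, not the actual functional.py internals.
B, L, E = 4, 10, 16
k = torch.randn(B, L, E)                            # keys: [B, L, E]
bias_k = torch.nn.Parameter(torch.zeros(1, 1, E))   # learnable bias: [1, 1, E]

bias_k_batched = bias_k.repeat(B, 1, 1)             # repeat along batch -> [B, 1, E]
k = torch.cat([k, bias_k_batched], dim=1)           # concat along sequence -> [B, L+1, E]
print(k.shape)                                      # torch.Size([4, 11, 16])
```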

Thank you so much @lnair for the info!

However, now I really wonder what the difference is between the add_bias_kv and add_zero_attn arguments: MultiheadAttention — PyTorch 2.0 documentation

To me, both sound very similar. The only apparent differences are that:

  1. add_bias_kv seems to be concatenated along dim=0 (the batch dimension?), while add_zero_attn is done at dim=1 (the sequence dimension?)
  2. add_bias_kv adds learnable parameters, while add_zero_attn fills with zeros (see the sketch below).
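
To make the comparison concrete, here is my rough sketch of what add_zero_attn seems to amount to, again in a batch-first [B, L, E] layout (illustrative names only, not the actual functional.py internals):

```python
import torch

# Rough sketch of what add_zero_attn appears to do, in a
# batch-first [B, L, E] layout (illustrative names only).
B, L, E = 4, 10, 16
k = torch.randn(B, L, E)
v = torch.randn(B, L, E)

# Append one all-zero position along the sequence dimension of k and v,
# so every query can also attend to a fixed zero entry.
zeros = torch.zeros(B, 1, E)
k = torch.cat([k, zeros], dim=1)   # [B, L+1, E]
v = torch.cat([v, zeros], dim=1)   # [B, L+1, E]
print(k.shape, v.shape)            # torch.Size([4, 11, 16]) torch.Size([4, 11, 16])
```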

This is the only info I have been able to find about add_zero_attn. But then I have no idea what purpose add_bias_kv would serve…

Does anyone have more insights on this?