Understanding the output dimensionality for torch.nn.MultiheadAttention.forward

I want to implement cross attention between two modalities. In my implementation, I take Q from modality A, and K and V from modality B. Modality A is used for guidance via cross attention, while the main operations are done on modality B.

Here is an example of my current implementation:

import torch

batch_size = 1
embedding_dims = 128
seqlen_A = 100
seqlen_B = 30

# num_heads is arbitrary here; it just has to divide embedding_dims
attn = torch.nn.MultiheadAttention(embedding_dims, num_heads=8, batch_first=True)

q = torch.randn(batch_size, seqlen_A, embedding_dims)
k = torch.randn(batch_size, seqlen_B, embedding_dims)
v = torch.randn(batch_size, seqlen_B, embedding_dims)

attn_out, attn_map = attn(q, k, v)


And I notice that the output dimensionality of attn_out is (1, 100, 128), which is the same as q’s shape, not v’s.

My intuition of the attention mechanism is that q and k are used to extract their relationship to each other, while v carries the actual values. That’s why I set Q from modality A (used only for guidance) and K, V from modality B, which is what I mostly care about. But since attn_out has the same dimensionality as Q, not V, I am a little lost.

Is my understanding of the attention mechanism wrong? And how can I implement cross attention between modalities A and B so that the output has the same shape as the latent variable of modality B?
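For reference, when I swap the roles — Q from modality B, and K, V from modality A — the output shape does match modality B (num_heads=8 is an arbitrary choice here):

```python
import torch

batch_size, embedding_dims = 1, 128
seqlen_A, seqlen_B = 100, 30

attn = torch.nn.MultiheadAttention(embedding_dims, num_heads=8, batch_first=True)

a = torch.randn(batch_size, seqlen_A, embedding_dims)  # modality A (guidance)
b = torch.randn(batch_size, seqlen_B, embedding_dims)  # modality B (main stream)

# Query from modality B, key/value from modality A:
attn_out, attn_map = attn(query=b, key=a, value=a)
print(attn_out.shape)  # (1, 30, 128) -- same shape as modality B
print(attn_map.shape)  # (1, 30, 100) -- one weight row per modality-B position
```

But I am not sure whether this arrangement is what is conventionally meant by "guiding B with A".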

For reference, the docstring of torch.nn.MultiheadAttention.forward is below.

Signature:
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    key_padding_mask: Optional[torch.Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[torch.Tensor] = None,
    average_attn_weights: bool = True,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]
Source:
def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True, attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query: Query embeddings of shape :math:(L, E_q) for unbatched input, :math:(L, N, E_q) when batch_first=False
or :math:(N, L, E_q) when batch_first=True, where :math:L is the target sequence length,
:math:N is the batch size, and :math:E_q is the query embedding dimension embed_dim.
Queries are compared against key-value pairs to produce the output.
See "Attention Is All You Need" for more details.
key: Key embeddings of shape :math:(S, E_k) for unbatched input, :math:(S, N, E_k) when batch_first=False
or :math:(N, S, E_k) when batch_first=True, where :math:S is the source sequence length,
:math:N is the batch size, and :math:E_k is the key embedding dimension kdim.
See "Attention Is All You Need" for more details.
value: Value embeddings of shape :math:(S, E_v) for unbatched input, :math:(S, N, E_v) when
batch_first=False or :math:(N, S, E_v) when batch_first=True, where :math:S is the source
sequence length, :math:N is the batch size, and :math:E_v is the value embedding dimension vdim.
See "Attention Is All You Need" for more details.
key_padding_mask: If specified, a mask of shape :math:(N, S) indicating which elements within key
to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched query, shape should be :math:(S).
Binary and byte masks are supported.
For a binary mask, a True value indicates that the corresponding key value will be ignored for
the purpose of attention. For a float mask, it will be directly added to the corresponding key value.
need_weights: If specified, returns attn_output_weights in addition to attn_outputs.
Default: True.
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
:math:(L, S) or :math:(N\cdot\text{num\_heads}, L, S), where :math:N is the batch size,
:math:L is the target sequence length, and :math:S is the source sequence length. A 2D mask will be
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
Binary, byte, and float masks are supported. For a binary mask, a True value indicates that the
corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
the attention weight.
average_attn_weights: If true, indicates that the returned attn_weights should be averaged across
heads. Otherwise, attn_weights are provided separately per head. Note that this flag only has an
effect when need_weights=True. Default: True (i.e. average weights across heads)

Outputs:
- **attn_output** - Attention outputs of shape :math:(L, E) when input is unbatched,
:math:(L, N, E) when batch_first=False or :math:(N, L, E) when batch_first=True,
where :math:L is the target sequence length, :math:N is the batch size, and :math:E is the
embedding dimension embed_dim.
- **attn_output_weights** - Only returned when need_weights=True. If average_attn_weights=True,
returns attention weights averaged across heads of shape :math:(L, S) when input is unbatched or
:math:(N, L, S), where :math:N is the batch size, :math:L is the target sequence length, and
:math:S is the source sequence length. If average_attn_weights=False, returns attention weights per
head of shape :math:(\text{num\_heads}, L, S) when input is unbatched or :math:(N, \text{num\_heads}, L, S).

.. note::
batch_first argument is ignored for unbatched inputs.
"""


It says “**attn_output** - Attention outputs of shape :math:(L, E) when input is unbatched, :math:(L, N, E) when batch_first=False …”, and I don’t understand why it is not (S, E) or (S, N, E).

I ran the example code (as shown above in the question)