Question on nn.MultiheadAttention

Hello everyone, I’m trying to implement the Transformer for the first time using nn.MultiheadAttention, but I’m quite confused. Could someone help me with these two questions? (Reference: the documentation of nn.MultiheadAttention.)

  1. Why isn’t the batch size the first dimension of query? (The shape of query for nn.MultiheadAttention is [target length, batch size, embed dim].) I would expect [batch size, target length, embed dim] instead. torch.utils.data.DataLoader always puts the batch size in the first dimension, and most other neural network modules do as well. What is the reason behind this? I’d like to know in case I’m missing something. (See the shape sketch after these questions.)

  2. nn.MultiheadAttention returns attn_output and attn_output_weights. Why does it return attn_output_weights? Wouldn’t it be sufficient to just use attn_output? Do I need attn_output_weights anywhere in the Transformer?
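For what it’s worth, here is a minimal sketch of how the documented shapes line up. The module sizes and sequence lengths below are arbitrary, chosen just for illustration; the transposes show one common way to go from DataLoader-style batch-first tensors to the (seq_len, batch, embed_dim) layout the module expects:

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 16, 4
mha = nn.MultiheadAttention(embed_dim, num_heads)

# DataLoader-style batch-first tensors: (batch, seq_len, embed_dim)
batch_size, tgt_len, src_len = 8, 10, 12
query = torch.randn(batch_size, tgt_len, embed_dim)
key = torch.randn(batch_size, src_len, embed_dim)
value = torch.randn(batch_size, src_len, embed_dim)

# nn.MultiheadAttention expects (seq_len, batch, embed_dim), so transpose first
q = query.transpose(0, 1)   # (tgt_len, batch, embed_dim)
k = key.transpose(0, 1)     # (src_len, batch, embed_dim)
v = value.transpose(0, 1)   # (src_len, batch, embed_dim)

attn_output, attn_output_weights = mha(q, k, v)
print(attn_output.shape)          # torch.Size([10, 8, 16]) -> (tgt_len, batch, embed_dim)
print(attn_output_weights.shape)  # torch.Size([8, 10, 12]) -> (batch, tgt_len, src_len)
```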


UP! I am also interested in that.

  1. I think this order is needed for the internal bmm operation, which does the main work. It is also consistent with the default layout of PyTorch’s RNN modules.
  2. As the need_weights argument suggests, the weights can be ignored (see the sketch below).
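A small sketch of what that looks like in practice; the sizes here are arbitrary, and the point is only that a Transformer sublayer consumes attn_output while the weights are usually requested just for inspection or visualization:

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=16, num_heads=4)
q = torch.randn(10, 8, 16)       # (tgt_len, batch, embed_dim)
k = v = torch.randn(12, 8, 16)   # (src_len, batch, embed_dim)

# Simply discard the weights ...
attn_output, _ = mha(q, k, v)

# ... or skip computing the averaged weights via need_weights=False
attn_output, attn_weights = mha(q, k, v, need_weights=False)
print(attn_weights)  # None
```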

@ptrblck Hello. I am also very interested in this question. Putting the sequence dimension first would seem to make the matrix multiply quite inefficient because of the resulting memory access pattern. Is this really just for compatibility with RNNs?
If it were designed completely from scratch, the most reasonable order would appear to be (N, E, S) rather than the current (S, N, E). That would make the sequence-length axes contiguous, and those are the axes being contracted in the (self-attention matrix @ value matrix) operation. Since that multiplication is the most computationally intensive step, it is the one that should be most heavily optimized.
cuDNN and cuBLAS optimizations probably mean that the (N, S, E) order, which torchtext calls batch_first, has similar performance, though this should be benchmarked.
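For context, a minimal sketch of the (self-attention matrix @ value matrix) multiplication being discussed, written with the per-head shapes PyTorch’s implementation folds into the batch dimension (the concrete sizes are arbitrary):

```python
import torch

batch, heads, tgt_len, src_len, head_dim = 8, 4, 10, 12, 16

# attention probabilities and values with the heads folded into the batch dim
attn = torch.randn(batch * heads, tgt_len, src_len)  # (N*H, L, S)
v = torch.randn(batch * heads, src_len, head_dim)    # (N*H, S, E/H)

# the source-sequence axis S is the one being contracted
out = torch.bmm(attn, v)                             # (N*H, L, E/H)
print(out.shape)  # torch.Size([32, 10, 16])
```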