Multi_head_attention_forward() and batch dimension index

In multi_head_attention_forward under torch.functional, there is a check asserting that the batch is the second dimension of the input tensor. However, it then calls linear(), which (as I understand it) requires the batch to be the first dimension of the input tensor. This does not seem correct to me; please correct me if I am wrong. Thank you.

def multi_head_attention_forward(…):
    assert list(query.size()) == [tgt_len, bsz, embed_dim]
    …
    q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])

Actually, linear() treats every dimension except the last as a batch dimension, so this is not an issue here.
