I am a bit confused about why, for batch size > 1 in the MultiHeadAttention class, attention is calculated by passing the input as <batch_size*num_heads, seq_length, embed_dim> and only later reshaped to <batch_size, num_heads, seq_length, embed_dim> to get the average attention across heads. Shouldn’t the input be reshaped first to split by batch, and attention calculated after that?
I don’t think I have understood the question, but this video might help: torch.nn.MultiheadAttention
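The purpose of multiple heads is not to parallelize computation. Multiple heads work much like convolutional filters, each learning its own logic. The data split you mentioned is not a batch-wise split: basic attention from "Attention Is All You Need" splits Q, K, and V into separate chunks along the embedding dimension and feeds each chunk to a different head. Folding the heads into the batch dimension only lets a single batched matrix multiply handle every (sample, head) pair independently, so samples are never mixed and reshaping back afterwards gives the same result as splitting by batch first.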
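To make the point about the reshape concrete, here is a minimal sketch, not the actual torch.nn.MultiheadAttention source. It assumes Q, K, and V are already projected (random tensors stand in for them), and the helper names `split_heads`, `fold`, and `attention` are made up for illustration. It computes attention once with batch and heads as separate dimensions and once with them folded into one leading dimension, then checks that both give the same result.

```python
import torch

torch.manual_seed(0)

batch_size, num_heads, seq_len, embed_dim = 2, 4, 5, 16
head_dim = embed_dim // num_heads  # each head gets a slice of the embedding

# Stand-ins for the already-projected queries, keys, and values (assumption).
q = torch.randn(batch_size, seq_len, embed_dim)
k = torch.randn(batch_size, seq_len, embed_dim)
v = torch.randn(batch_size, seq_len, embed_dim)


def split_heads(x):
    # Hypothetical helper: (batch, seq, embed) -> (batch, heads, seq, head_dim)
    return x.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)


def fold(x):
    # Hypothetical helper: merge batch and heads into one leading dimension.
    return x.reshape(batch_size * num_heads, seq_len, head_dim)


def attention(q, k, v):
    # Scaled dot-product attention; matmul treats all leading dims as batch dims.
    scores = q @ k.transpose(-2, -1) / head_dim ** 0.5
    weights = scores.softmax(dim=-1)
    return weights @ v, weights


qh, kh, vh = split_heads(q), split_heads(k), split_heads(v)

# Option A: keep batch and heads as separate dimensions.
out_4d, w_4d = attention(qh, kh, vh)

# Option B: fold batch and heads together, then unfold afterwards.
out_3d, w_3d = attention(fold(qh), fold(kh), fold(vh))
out_unfolded = out_3d.reshape(batch_size, num_heads, seq_len, head_dim)
w_unfolded = w_3d.reshape(batch_size, num_heads, seq_len, seq_len)

same_output = torch.allclose(out_4d, out_unfolded, atol=1e-6)
same_weights = torch.allclose(w_4d, w_unfolded, atol=1e-6)

# Averaging over the head dimension gives one attention map per sample.
avg_weights = w_unfolded.mean(dim=1)  # (batch, seq, seq)
```

Each slice along the folded dimension is one (sample, head) pair, and the batched matmul never mixes slices, so the order of "split by batch" and "compute attention" does not matter. Note that in this sketch the attention weights come out as seq_len x seq_len per head, and each head works on head_dim = embed_dim / num_heads features rather than the full embed_dim.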
Thank you for the explanation!