MultiheadAttention with multiple batch dimensions

For my use case, the input has shape (N, M, L, E), similar to the format nn.MultiheadAttention expects but with an additional "batch" dimension M. Unfortunately, nn.MultiheadAttention does not accept this 4D tensor.

How can I implement self/cross-attention over the second-to-last dimension of a 4D tensor? I'm hoping for an interface similar to nn.Linear, which accepts tensors of shape (*, E).

Ideally, the solution shouldn't flatten the N and M dimensions into a single batch dimension, since the output must remain a 4D tensor, unless there is an efficient way to reconstruct the 4D structure from the 3D result with the appropriate attention masks.
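
For reference, the flattening route mentioned above would look roughly like the sketch below. Since nn.MultiheadAttention never attends across batch entries, merging N and M into one batch dimension needs no extra mask, just a reshape back afterwards (names here are illustrative):

import torch
from torch import nn

N, M, L, E = 6, 4, 20, 8
mha = nn.MultiheadAttention(embed_dim=E, num_heads=2, batch_first=True)

x = torch.randn((N, M, L, E))

# Merge the two leading batch dimensions, run attention, then restore them.
# Each of the N*M entries is attended to independently, so no mask is needed
# to keep them separate.
x_flat = x.reshape(N * M, L, E)
out_flat, weights_flat = mha(x_flat, x_flat, x_flat)
out = out_flat.reshape(N, M, L, E)          # (N, M, L, E)
weights = weights_flat.reshape(N, M, L, L)  # (N, M, L, L)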

Would it work for you to just vmap the model’s forward pass?

import torch
from torch import nn, vmap

N, M, L, E = 6, 4, 20, 8
model = nn.MultiheadAttention(embed_dim=E, num_heads=2, batch_first=True)

query = torch.randn((N, M, L, E))
key = torch.randn((N, M, L, E))
value = torch.randn((N, M, L, E))

# vmap maps over the leading N dimension (in_dims=0 by default), so the module
# sees ordinary (M, L, E) batches while its parameters are shared across N.
attn_output, attn_output_weights = vmap(model)(query, key, value)

print(attn_output.shape)
print(attn_output_weights.shape)

Output:

torch.Size([6, 4, 20, 8])
torch.Size([6, 4, 20, 20])

Yes, it works great! I read up on torch.vmap; it vectorizes the call over the mapped dimension (so the batched work still runs in parallel on the GPU) and composes with autograd, so gradients flow as usual. Thank you!
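
As a quick check of the gradient claim, reusing the model and tensors from the snippet above, backpropagating through the vmapped call populates the module's parameter gradients:

attn_output, _ = vmap(model)(query, key, value)
attn_output.sum().backward()

# Gradients reach the shared projection weights, e.g. in_proj_weight of shape (3*E, E)
print(model.in_proj_weight.grad.shape)  # torch.Size([24, 8])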
