For my use-case, the input dimensions are (N, M, L, E), i.e. the format nn.MultiheadAttention expects but with an additional “batch” dimension M. Unfortunately, nn.MultiheadAttention does not accept this 4D tensor.
How can I implement self/cross-attention over the second-to-last dimension of a 4D tensor? I’m hoping for an interface similar to nn.Linear, which accepts tensors of shape (*, E).
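For concreteness, here is a minimal sketch of the mismatch I mean (the sizes are made up purely for illustration):

```python
import torch
import torch.nn as nn

# Hypothetical sizes, just for illustration.
N, M, L, E = 8, 4, 16, 64
x = torch.randn(N, M, L, E)

# nn.Linear broadcasts over all leading dimensions, so the 4D input works directly.
proj = nn.Linear(E, E)
print(proj(x).shape)  # torch.Size([8, 4, 16, 64])

# nn.MultiheadAttention only supports 2D or 3D inputs, so this raises an error.
mha = nn.MultiheadAttention(embed_dim=E, num_heads=8, batch_first=True)
# mha(x, x, x)  # fails: query is expected to be 2-D or 3-D
```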
Ideally, the solution shouldn’t involve flattening the N and M dimensions into a single dimension, since the output must remain a 4D tensor, unless there is an efficient way to reconstruct the 4D structure from the 3D output with the appropriate attention masks (a rough sketch of that workaround is below).
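If flattening does turn out to be acceptable, this is roughly the merge-and-restore approach I have in mind (again with made-up sizes):

```python
import torch
import torch.nn as nn

N, M, L, E = 8, 4, 16, 64  # hypothetical sizes
x = torch.randn(N, M, L, E)

mha = nn.MultiheadAttention(embed_dim=E, num_heads=8, batch_first=True)

# Merge N and M into one batch dimension, attend over L, then restore the 4D shape.
x_flat = x.reshape(N * M, L, E)
out_flat, _ = mha(x_flat, x_flat, x_flat)
out = out_flat.reshape(N, M, L, E)
print(out.shape)  # torch.Size([8, 4, 16, 64])
```

As far as I can tell, reshaping back here should be lossless, because attention is computed independently for each batch element, so tokens from different (n, m) pairs never attend to each other and no extra masks seem to be needed. My concern is mainly whether this reshape round-trip is the efficient/idiomatic way to do it, or whether there is a cleaner interface.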