Understanding the attn_mask param in nn.MultiheadAttention

I have a point cloud with n points, and each point has d features. I want to use the attn_mask in the nn.MultiheadAttention module to do local attention, so that each point only attends to its neighbors. Should I pass the adjacency matrix to achieve that?

Well, I think this would work (but you should invert the adjacency matrix, because True in attn_mask means that position is masked out), though it is NOT a good choice.
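
Here is a minimal sketch of what I mean, assuming your adjacency matrix is a boolean (n, n) tensor where True marks a neighbor; the sizes and the random adjacency are just placeholders for illustration:

```python
import torch
import torch.nn as nn

n, d, num_heads = 5000, 64, 4
points = torch.randn(n, d)            # (n, d) point features
adj = torch.rand(n, n) < 0.01         # stand-in for a real adjacency matrix
adj.fill_diagonal_(True)              # keep self-attention so no row is fully masked

mha = nn.MultiheadAttention(embed_dim=d, num_heads=num_heads, batch_first=True)

# attn_mask uses the opposite convention of an adjacency matrix:
# True means "do NOT attend to this position", so invert it.
attn_mask = ~adj                      # (n, n) bool, True = masked out

x = points.unsqueeze(0)               # (1, n, d): add a batch dimension
out, _ = mha(x, x, x, attn_mask=attn_mask)
```

Note that you need at least one unmasked entry per row (here, the diagonal), otherwise the softmax over an all-masked row produces NaNs.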

If you have N points and every point has M neighbors, where M << N, the cost of an attention layer restricted to neighbors should be O(N×M). But if you pass attn_mask directly to the attention layer, the cost is still O(N×N). That's because attn_mask just fills the masked positions with -inf to neutralize the softmax, but the scores for those positions are still computed. If you have many points (e.g. 5000) but M is small (e.g. 10), this leads to a lot of unnecessary computation.
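
A rough, simplified sketch of what happens inside a masked attention layer (one head, no projections; the names here are illustrative, not the actual internals):

```python
import torch
import torch.nn.functional as F

N, d = 5000, 64
q = k = v = torch.randn(N, d)
mask = torch.rand(N, N) < 0.99        # True = masked (stand-in mask)
mask.fill_diagonal_(False)

scores = q @ k.T / d**0.5             # full (N, N) score matrix, regardless of the mask
scores = scores.masked_fill(mask, float('-inf'))
weights = F.softmax(scores, dim=-1)   # masked positions end up with weight 0
out = weights @ v                     # still an O(N x N x d) matmul
```

The mask only zeroes out the weights after the fact; the full (N, N) matmuls are still performed.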

I understand your concern. However, this is the only way I know of to take advantage of GPU parallelization. Do you have any suggestions?

In 2D image processing, Neighborhood Attention addresses a similar need.

There are also some other methods I can recall. For example, gather each of the N tokens' M neighbors into an (N, M) group of tokens, then run attention over these groups, treating N as the batch dimension; a sketch follows below. This requires weighing the trade-off between compute and memory.
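
A sketch of that grouping idea, under the assumption that you have a precomputed (N, M) index tensor `neighbor_idx` (e.g. from a k-NN search) giving each point's M neighbors; the names, shapes, and random indices are illustrative:

```python
import torch
import torch.nn.functional as F

N, M, d = 5000, 10, 64
points = torch.randn(N, d)
neighbor_idx = torch.randint(0, N, (N, M))   # stand-in for real k-NN indices

# Gather each point's neighbors: (N, M, d)
neighbors = points[neighbor_idx]

# Treat N as the batch dimension: each point attends only to its M neighbors.
q = points.unsqueeze(1)                      # (N, 1, d) query per point
k = v = neighbors                            # (N, M, d)

scores = (q @ k.transpose(1, 2)) / d**0.5    # (N, 1, M)
weights = F.softmax(scores, dim=-1)
out = (weights @ v).squeeze(1)               # (N, d), roughly O(N x M x d) compute
```

The compute now scales with N×M instead of N×N, but the gathered `neighbors` tensor costs O(N×M×d) memory, which is exactly the compute/memory trade-off mentioned above.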