Efficiency Optimization for Dot Product Multihead Attention

Hi, I want to perform a batched computation of dot product multi-head attention defined below:


Currently I’m using batched matrix multiplication for the computation, as follows:

"""
- H_i of shape (N, K, 1, F)
- W of shape (N, K, F, F), where the first dimension is obtained with .repeat (the same K weight matrices copied for each of the N samples)
- H_j of shape (N, K, F, 1)
- N for the number of samples
- K for the number of heads
- F for n above
"""

H_i @ W @ H_j  # result of shape (N, K, 1, 1)
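Reading the shapes in the docstring, the expression produces one scalar per sample n and head k. In the notation below (mine, not from the original definition), h_i^{(n,k)} and h_j^{(n,k)} are the length-F vectors for that sample and head, and W_k is head k's F×F matrix before the .repeat:

```latex
s_{n,k} = \mathbf{h}_i^{(n,k)\top} \, W_k \, \mathbf{h}_j^{(n,k)}
        = \sum_{a=1}^{F} \sum_{b=1}^{F} h_{i,a}^{(n,k)} \, (W_k)_{ab} \, h_{j,b}^{(n,k)}
```

W_k carries no n index, so the .repeat only stores N identical copies of the same K matrices. That repeated tensor has N·K·F·F elements, which is F times larger than H_i or H_j.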

The performance is poor in terms of both speed and GPU memory usage. I’m wondering whether there are better approaches for this computation.
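One possible direction, sketched under assumptions and not benchmarked: keep W as a single (K, F, F) tensor instead of repeating it, store H_i and H_j as (N, K, F) without the singleton dimensions, and let `torch.einsum` or a single `torch.bmm` over heads do the contraction. The function names `scores_repeat`, `scores_einsum` and `scores_bmm` and the demo sizes are hypothetical; `scores_repeat` mirrors the approach in the post so the results can be compared. Neither alternative builds the (N, K, F, F) tensor, so peak memory should drop by about that allocation. Whether it is also faster depends on N, K, F and the GPU.

```python
import torch


def scores_repeat(h_i, w, h_j):
    """Approach from the post: repeat W over samples, then chained matmul.

    h_i: (N, K, 1, F), w: (K, F, F), h_j: (N, K, F, 1) -> (N, K)
    """
    n = h_i.shape[0]
    w_rep = w.unsqueeze(0).repeat(n, 1, 1, 1)  # (N, K, F, F): N copies of W
    return (h_i @ w_rep @ h_j).squeeze(-1).squeeze(-1)


def scores_einsum(h_i, w, h_j):
    """Same bilinear form without materializing the repeated W.

    h_i: (N, K, F), w: (K, F, F), h_j: (N, K, F) -> (N, K)
    """
    return torch.einsum("nkf,kfg,nkg->nk", h_i, w, h_j)


def scores_bmm(h_i, w, h_j):
    """Explicit variant: one bmm batched over heads, then a per-row dot product."""
    # (K, N, F) @ (K, F, F) -> (K, N, F): the N samples are the rows of each head's matmul
    hw = torch.bmm(h_i.transpose(0, 1), w)
    # back to (N, K, F), multiply elementwise with h_j and sum over F -> (N, K)
    return (hw.transpose(0, 1) * h_j).sum(dim=-1)


if __name__ == "__main__":
    torch.manual_seed(0)
    N, K, F = 32, 8, 64  # hypothetical sizes for illustration
    # float64 so the comparison is not dominated by rounding differences
    h_i = torch.randn(N, K, F, dtype=torch.float64)
    h_j = torch.randn(N, K, F, dtype=torch.float64)
    w = torch.randn(K, F, F, dtype=torch.float64)

    ref = scores_repeat(h_i.unsqueeze(-2), w, h_j.unsqueeze(-1))
    print(torch.allclose(ref, scores_einsum(h_i, w, h_j)))
    print(torch.allclose(ref, scores_bmm(h_i, w, h_j)))
```

`scores_bmm` spells out the same contraction by hand: treating N as the row dimension of each head's matrix product means each W_k is read once per head rather than copied once per sample.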