Hi, I want to perform a batched computation of the dot-product multi-head attention defined below.
Currently I'm computing it with batched matrix multiplication, as shown below:
""" - H_i of shape (N, K, 1, F) - W of shape (N, K, F, F) where the first dimension is obtain with .repeat - H_j of shape (N, K, F, 1) - N for number of samples - K for number of heads - F for n above """ H_i @ W @ H_j
The performance is not great, in terms of both speed and GPU memory usage. Are there better approaches for performing this computation?
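Concretely, I'm wondering whether something along these lines is the right direction: an `einsum` that broadcasts the per-head `W` over the batch instead of materializing it with `.repeat`. Sketch with the same assumed sizes as above (note the output here is (N, K) rather than (N, K, 1, 1)):

```python
import torch

N, K, F = 32, 8, 64  # assumed illustrative sizes
H_i = torch.randn(N, K, F)   # query-side features
H_j = torch.randn(N, K, F)   # key-side features
W = torch.randn(K, F, F)     # per-head weight, NOT repeated across N

# Contract both feature indices in one call; W is broadcast over the
# batch dimension n, so no (N, K, F, F) tensor is ever allocated.
scores = torch.einsum('nkf,kfg,nkg->nk', H_i, W, H_j)  # (N, K)

# Equivalent two-step form: mix features once per head, then dot product
tmp = torch.einsum('nkf,kfg->nkg', H_i, W)
scores2 = (tmp * H_j).sum(-1)
assert torch.allclose(scores, scores2, atol=1e-4)
```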