Get the main diagonal of the product of two large matrices

I am trying to get the main diagonal of the product of two large matrices. Here is my implementation:

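import torch

def col_wise_mul(m1, m2):
    # Assumed shapes: m1 is (B, N, K), m2 is (K, N); returns (B, N).
    cols = []
    for i in range(m1.shape[1]):
        v1 = m1[:, i, :]  # (B, K): row i of every batch element
        v2 = m2[:, i]     # (K,): column i of m2
        cols.append(torch.matmul(v1, v2))  # (B,)
    # Stack once at the end instead of growing a tensor with torch.cat.
    return torch.stack(cols, dim=1)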
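Assuming m1 has shape (B, N, K) and m2 has shape (K, N), as the indexing m1[:, i, :] and m2[:, i] suggests, each entry the loop fills is exactly one element of the batched diagonal of the product. The last form is the one that leads to a loop-free rewrite:

```latex
\mathrm{result}_{b,i} = \sum_{k=1}^{K} (m_1)_{b,i,k}\,(m_2)_{k,i}
                      = \left(m_1 m_2\right)_{b,i,i}
                      = \sum_{k=1}^{K} (m_1)_{b,i,k}\,(m_2^{\top})_{i,k}
```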

I know I could multiply the two matrices first and then take the diagonal, like this:

x = torch.diagonal(torch.matmul(x_feature, label_feature), offset=0).transpose(0, 1)
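A minimal sketch of the "full product, then diagonal" route, assuming the same hypothetical shapes as the loop (m1 is (B, N, K), m2 is (K, N)). For a batched (B, N, N) product, torch.diagonal's defaults (dim1=0, dim2=1) pair the batch dimension with the row dimension, so passing dim1=-2, dim2=-1 takes the per-batch diagonal directly as a (B, N) tensor with no transpose. This still materializes the whole (B, N, N) product, which is what you want to avoid for large inputs.

```python
import torch

torch.manual_seed(0)
# Hypothetical shapes, inferred from the question's loop: m1 is (B, N, K), m2 is (K, N).
B, N, K = 4, 5, 3
m1 = torch.randn(B, N, K)
m2 = torch.randn(K, N)

# Full product: (B, N, K) @ (K, N) -> (B, N, N); m2 is broadcast over the batch.
full = torch.matmul(m1, m2)

# Diagonal over the last two dims -> (B, N).
diag_last = torch.diagonal(full, dim1=-2, dim2=-1)

# Reference: the column-by-column loop from the question.
loop = torch.stack([m1[:, i, :] @ m2[:, i] for i in range(N)], dim=1)

print(diag_last.shape)  # torch.Size([4, 5])
print(torch.allclose(diag_last, loop, atol=1e-6))  # True
```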

Since both matrices are very large, are there any improvements I could make to my implementation? Mainly I want to avoid the loop, because this runs in the forward pass.

Thanks!


Hi,

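I think the best way to do this (assuming 2D matrices) is to rewrite matmul(m1, m2).diag() as (m1 * m2.t()).sum(dim=1). Entry i of the diagonal is sum_k m1[i, k] * m2[k, i], so this computes only the N diagonal entries instead of the full N x N product.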
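A quick check of the 2D identity, using hypothetical random matrices with m1 of shape (N, K) and m2 of shape (K, N). The elementwise version does O(N·K) work and never allocates the N x N product, while matmul followed by diag does O(N²·K) work.

```python
import torch

torch.manual_seed(0)
# Hypothetical 2D example: m1 is (N, K), m2 is (K, N).
N, K = 6, 4
m1 = torch.randn(N, K)
m2 = torch.randn(K, N)

# Reference: build the full N x N product, then keep only its diagonal.
ref = torch.matmul(m1, m2).diag()

# Diagonal only: diag(m1 @ m2)[i] = sum_k m1[i, k] * m2[k, i]
#                                  = sum_k m1[i, k] * m2.t()[i, k]
fast = (m1 * m2.t()).sum(dim=1)

print(fast.shape)  # torch.Size([6])
print(torch.allclose(fast, ref, atol=1e-6))  # True
```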
If you have an extra batch dimension, you can handle it by just changing the dimension you sum over.
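A sketch of the batched case, assuming the shapes implied by the question's loop (m1 is (B, N, K), m2 is (K, N)). m2.t() is (N, K) and broadcasts against (B, N, K), so the sum moves to the last dimension. torch.einsum is shown as an alternative spelling of the same contraction; it is not part of the answer above. If m2 were itself batched as (B, K, N), m2.transpose(-2, -1) would replace m2.t(), which only accepts tensors with at most 2 dimensions.

```python
import torch

torch.manual_seed(0)
# Hypothetical batched shapes matching the question's loop: m1 is (B, N, K), m2 is (K, N).
B, N, K = 3, 5, 4
m1 = torch.randn(B, N, K)
m2 = torch.randn(K, N)

# m2.t() is (N, K) and broadcasts over the batch; sum over K, the last dim -> (B, N).
fast = (m1 * m2.t()).sum(dim=-1)

# Same contraction with einsum: out[b, i] = sum_k m1[b, i, k] * m2[k, i].
fast_einsum = torch.einsum("bik,ki->bi", m1, m2)

# Reference: full (B, N, N) product, diagonal over the last two dims.
ref = torch.diagonal(torch.matmul(m1, m2), dim1=-2, dim2=-1)

print(fast.shape)  # torch.Size([3, 5])
print(torch.allclose(fast, ref, atol=1e-6), torch.allclose(fast_einsum, ref, atol=1e-6))  # True True
```

Neither version builds the (B, N, N) product, and both are single vectorized ops, so autograd handles them in the forward pass without a Python loop.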


Thanks! It really helps!