# Get the main diagonal of the product of two large matrices

I am trying to get the main diagonal from the multiplication of two large matrices. Here is my implementation:

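```python
import torch

def col_wise_mul(m1, m2):
    result = torch.zeros(0)
    # Loop over dim 1 of m1, the dimension indexed by i below
    for i in range(m1.shape[1]):
        v1 = m1[:, i, :]
        v2 = m2[:, i]
        # One column of the output per iteration
        v = torch.matmul(v1, v2).unsqueeze(1)
        result = torch.cat((result, v), dim=1)
    return result
```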

I know that I could multiply the two matrices first and then take the diagonal, as below:

```python
x = torch.diagonal(torch.matmul(x_feature, label_feature), offset=0).transpose(0, 1)
```

As the two matrices are very large, I am wondering whether my implementation could be improved. Basically, I want to avoid the loop, since this code runs in the forward pass.

Thanks!


Hi,

I think the best way to do this is to rewrite `matmul(m1, m2).diag()` (assuming 2D matrices) as `(m1 * m2.t()).sum(dim=1)`. This computes only the diagonal entries and never materializes the full product.
If you have an extra batch dimension, you can handle it by just changing the dimension you sum over.
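Here is a minimal sketch of both cases. The shapes are my guess from the loop in the question: `m1` is `(N, D)` and `m2` is `(D, N)` in the 2D case, and `m1` is `(B, N, D)` in the batched case. The function names are made up for illustration.

```python
import torch

def matmul_diag(m1, m2):
    """Diagonal of m1 @ m2 without forming the full product.

    Assumed shapes: m1 is (N, D), m2 is (D, N); returns (N,).
    """
    return (m1 * m2.t()).sum(dim=1)

def batched_matmul_diag(m1, m2):
    """Batched variant: entry [b, i] equals m1[b, i, :] @ m2[:, i].

    Assumed shapes: m1 is (B, N, D), m2 is (D, N); returns (B, N).
    m2.t() has shape (N, D) and broadcasts over the batch dimension.
    """
    return (m1 * m2.t()).sum(dim=-1)

torch.manual_seed(0)

# 2D case: compare against the full product followed by diag()
a = torch.randn(5, 3)
b = torch.randn(3, 5)
assert torch.allclose(matmul_diag(a, b), torch.matmul(a, b).diag(), atol=1e-5)

# Batched case: compare against the diagonal over the last two dims
x = torch.randn(4, 5, 3)
ref = torch.diagonal(torch.matmul(x, b), dim1=-2, dim2=-1)
assert torch.allclose(batched_matmul_diag(x, b), ref, atol=1e-5)

# The same contraction written with einsum
assert torch.allclose(torch.einsum("bnd,dn->bn", x, b), ref, atol=1e-5)
```

The elementwise product still allocates a temporary the size of `m1`, but that is much smaller than the `(B, N, N)` product when `N` is large. The `einsum` form expresses the same sum and may be worth benchmarking on your shapes.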


Thanks! It really helps!