I am trying to get the main diagonal of the product of two large matrices. Here is my implementation:
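```python
import torch

def col_wise_mul(m1, m2):
    # Shapes implied by the indexing: m1 is (A, N, K), m2 is (K, N)
    result = torch.zeros(0, dtype=m1.dtype, device=m1.device)  # empty (0,) tensor; torch.cat skips it
    for i in range(m1.shape[1]):
        v1 = m1[:, i, :]                          # (A, K)
        v2 = m2[:, i]                             # (K,)
        v = torch.matmul(v1, v2).unsqueeze(1)     # (A, 1): column i of the output
        result = torch.cat((result, v), dim=1)    # grow the (A, N) output one column at a time
    return result
```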
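Reading the shapes off the indexing (an assumption: `m1` is (A, N, K) and `m2` is (K, N)), each output entry is a dot product along K, which is exactly the (a, i, i) entry of the batched product `m1 @ m2`:

$$
\text{result}_{a,i} \;=\; \sum_{k=1}^{K} m1_{a,i,k}\, m2_{k,i} \;=\; \bigl(m1\, m2\bigr)_{a,i,i},
\qquad \bigl(m1\, m2\bigr)_{a,i,j} = \sum_{k=1}^{K} m1_{a,i,k}\, m2_{k,j}.
$$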
I know that I could multiply the two matrices first and then take the diagonal, like this:
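```python
# Assuming x_feature is (A, N, K) and label_feature is (K, N), the product is (A, N, N);
# take the diagonal of each N x N slice (the last two dims) to get an (A, N) result
x = torch.diagonal(torch.matmul(x_feature, label_feature), offset=0, dim1=-2, dim2=-1)
```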
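A quick check that the full-product route matches the loop, as a sketch with small, made-up sizes (`A`, `N`, `K` and the random tensors are illustrative stand-ins for `x_feature` and `label_feature`). Two points it makes visible: `torch.diagonal` defaults to `dim1=0, dim2=1`, so for a batched product the last two dims have to be passed explicitly, and `torch.matmul` materializes the whole (A, N, N) tensor even though only N entries per batch element are kept.

```python
import torch

torch.manual_seed(0)
A, N, K = 4, 5, 3                      # small, made-up sizes for illustration
m1 = torch.randn(A, N, K)              # stands in for x_feature
m2 = torch.randn(K, N)                 # stands in for label_feature

# Reference result from the loop: column i is m1[:, i, :] @ m2[:, i]
loop = torch.stack([m1[:, i, :] @ m2[:, i] for i in range(N)], dim=1)  # (A, N)

full = torch.matmul(m1, m2)                     # (A, N, N): the whole product is in memory
diag = torch.diagonal(full, dim1=-2, dim2=-1)   # (A, N): diagonal of each N x N slice

print(full.shape)   # torch.Size([4, 5, 5])
print(diag.shape)   # torch.Size([4, 5])
print(torch.allclose(loop, diag, atol=1e-6))
```

For large N the (A, N, N) intermediate is what makes this route expensive, which is why it only serves as a correctness reference here.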
Since both matrices are very large, is there any way to improve my implementation? Basically, I want to avoid the Python loop, because this code runs in the forward pass.
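One loop-free option is `torch.einsum`, sketched below under the same shape assumptions (`m1` as (A, N, K), `m2` as (K, N)); the function names are hypothetical. The subscript string `"aik,ki->ai"` says: keep `a` and `i`, sum over `k`, so it computes only the diagonal entries and does not need to build the (A, N, N) product. A broadcasting variant is included for comparison. Both are ordinary differentiable ops, so they should work inside a module's `forward` with autograd.

```python
import torch


def col_wise_mul_einsum(m1: torch.Tensor, m2: torch.Tensor) -> torch.Tensor:
    """Hypothetical loop-free replacement for col_wise_mul.

    Assumes m1 is (A, N, K) and m2 is (K, N). Returns an (A, N) tensor with
    out[a, i] = sum_k m1[a, i, k] * m2[k, i], i.e. the diagonal of m1 @ m2
    over its last two dims.
    """
    return torch.einsum("aik,ki->ai", m1, m2)


def col_wise_mul_broadcast(m1: torch.Tensor, m2: torch.Tensor) -> torch.Tensor:
    # m2.T is (N, K) and broadcasts against m1's last two dims (N, K).
    # This allocates an (A, N, K) temporary (same size as m1), then sums over K.
    return (m1 * m2.T).sum(dim=-1)


if __name__ == "__main__":
    A, N, K = 4, 5, 3                  # small, made-up sizes
    m1, m2 = torch.randn(A, N, K), torch.randn(K, N)
    loop = torch.stack([m1[:, i, :] @ m2[:, i] for i in range(N)], dim=1)
    print(torch.allclose(col_wise_mul_einsum(m1, m2), loop, atol=1e-5))
    print(torch.allclose(col_wise_mul_broadcast(m1, m2), loop, atol=1e-5))
```

Of the two, `einsum` is usually the more memory-friendly choice, since the broadcast version creates a temporary as large as `m1`; it is worth timing both on the real shapes.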
Thanks!