Is there a way in PyTorch to do the following (or is there a general mathematical term for this)?
Assume normal matrix multiplication (torch.mm):
M3[i,k] = sum_j(M1[i,j] * M2[j,k])    sizes: M1 is a×b, M2 is b×c
Now I would like to replace the sum with max:
M3[i,k] = max_j(M1[i,j] * M2[j,k])
As you can see, this is completely parallel to the above, except that we take the max over all j instead of the sum.
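To make the definition concrete, here is a minimal reference sketch with explicit loops on a tiny example. The function name `max_product_reference` is made up for illustration; it is far too slow for the real sizes and only pins down the expected result.

```python
import torch

def max_product_reference(M1, M2):
    """M3[i, k] = max_j(M1[i, j] * M2[j, k]); slow loops, for checking only."""
    a, b = M1.shape
    b2, c = M2.shape
    assert b == b2, "inner dimensions must match"
    M3 = torch.empty(a, c, dtype=M1.dtype)
    for i in range(a):
        for k in range(c):
            # Elementwise product of row i and column k, then max instead of sum
            M3[i, k] = (M1[i, :] * M2[:, k]).max()
    return M3

M1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
M2 = torch.tensor([[5.0, 6.0], [7.0, 8.0]])
M3 = max_product_reference(M1, M2)
print(M3)  # rows: [14, 16] and [28, 32]; torch.mm would give [19, 22] and [43, 50]
```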
NOTE: M1 and M2 can be extremely large (e.g. M1 of size 128×10000, or written in this form: (2×64)×(100×100)). In this case, a for-loop is not acceptable because of its running time. Because of memory constraints, we also can NOT do (M1.unsqueeze(2)*M2.unsqueeze(0)).max(dim=1), a method mentioned in the ref link, since it materializes the full a×b×c intermediate tensor.
Ref link: https://stackoverflow.com/questions/41164305/numpy-dot-product-with-max-instead-of-sum
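To illustrate the trade-off between the two rejected approaches: a possible middle ground is to apply the broadcast trick to a slice of M2's columns at a time, so the intermediate is a×b×chunk instead of a×b×c. This is only a sketch under that assumption, not a known-good solution. `max_product_chunked` and `chunk_size` are hypothetical names, and a Python loop over chunks remains, so it may still be too slow for the real sizes.

```python
import torch

def max_product_chunked(M1, M2, chunk_size=64):
    """M3[i, k] = max_j(M1[i, j] * M2[j, k]), computed over column chunks of M2."""
    a, b = M1.shape
    c = M2.shape[1]
    out = torch.empty(a, c, dtype=M1.dtype, device=M1.device)
    for start in range(0, c, chunk_size):
        end = min(start + chunk_size, c)
        # (a, b, 1) * (1, b, chunk) broadcasts to (a, b, chunk)
        prod = M1.unsqueeze(2) * M2[:, start:end].unsqueeze(0)
        # max over j (dim=1); .values drops the argmax indices
        out[:, start:end] = prod.max(dim=1).values
    return out

torch.manual_seed(0)
A = torch.randn(5, 7)
B = torch.randn(7, 10)
chunked = max_product_chunked(A, B, chunk_size=4)
# Full broadcast version, only feasible here because the tensors are tiny
full = (A.unsqueeze(2) * B.unsqueeze(0)).max(dim=1).values
print(torch.equal(chunked, full))
```

Peak extra memory is roughly a×b×chunk_size elements, so chunk_size trades memory for the number of loop iterations.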
Thanks in advance!