Is there a way in PyTorch to do the following (or is there a general mathematical term for this)?

Assume normal matrix multiplication (torch.mm):

M3[i,k] = sum_j(M1[i,j] * M2[j,k]), with sizes M1: a×b and M2: b×c

Now I would like to replace the sum with a max:

M3[i,k] = max_j(M1[i,j] * M2[j,k])

As you can see, it is completely parallel to the above, except that we take the `max` over all `j` instead of the sum.
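To pin down the definition, here is a tiny pure-Python reference on nested lists. It is only meant to show the intended result on a 2×2 case and is far too slow for real sizes; the function name `max_product` is made up for this illustration.

```python
def max_product(M1, M2):
    """Reference definition: M3[i][k] = max over j of M1[i][j] * M2[j][k].

    M1 is a list of a rows of length b; M2 is a list of b rows of length c.
    """
    b = len(M2)
    c = len(M2[0])
    return [
        [max(row[j] * M2[j][k] for j in range(b)) for k in range(c)]
        for row in M1
    ]


M1 = [[1, 2], [3, 4]]
M2 = [[5, 6], [7, 8]]
print(max_product(M1, M2))  # [[14, 16], [28, 32]]
```

For comparison, the ordinary matrix product of the same inputs is [[19, 22], [43, 50]].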

NOTE: M1 and M2 can be extremely large (e.g. M1 of size 128×10000, or in the form (2×64)×(100×100)). In this case, a for-loop is not acceptable because of its running time. Because of memory constraints, we can NOT do (M1.unsqueeze(2)*M2.unsqueeze(0)).max(dim=1) either, a method mentioned in the ref link.
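One possible middle ground between the two rejected options, as a minimal sketch only: apply the broadcast trick to a block of M2's columns at a time, so the intermediate tensor is a×b×chunk instead of a×b×c. The function name `max_product_chunked` and the `chunk_size` parameter are hypothetical, and whether this is fast enough at the sizes above is an open assumption, not something measured here.

```python
import torch


def max_product_chunked(M1, M2, chunk_size=64):
    """Compute M3[i, k] = max_j M1[i, j] * M2[j, k] over column chunks of M2.

    chunk_size is a hypothetical tuning knob: larger values mean fewer loop
    iterations but a larger (a, b, chunk_size) temporary tensor.
    """
    a, b = M1.shape
    b2, c = M2.shape
    assert b == b2, "inner dimensions must match"
    out = torch.empty(a, c, dtype=M1.dtype, device=M1.device)
    for start in range(0, c, chunk_size):
        end = min(start + chunk_size, c)
        # (a, b, 1) * (1, b, chunk) broadcasts to (a, b, chunk)
        prod = M1.unsqueeze(2) * M2[:, start:end].unsqueeze(0)
        # Reduce over j (dim=1); .values drops the argmax indices
        out[:, start:end] = prod.max(dim=1).values
    return out


M1 = torch.randn(5, 7)
M2 = torch.randn(7, 11)
M3 = max_product_chunked(M1, M2, chunk_size=4)
print(M3.shape)  # torch.Size([5, 11])
```

The loop here runs over chunks of columns rather than over individual elements, so the number of Python-level iterations is c / chunk_size.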

Ref link: https://stackoverflow.com/questions/41164305/numpy-dot-product-with-max-instead-of-sum

Thanks in advance!