PyTorch: matrix multiplication (@) with max instead of sum

Is there a way in PyTorch to do the following (or is there a general mathematical term for this)?
Assume normal matrix multiplication (torch.mm):
M3[i,k] = sum_j(M1[i,j] * M2[j,k]), with sizes M1: a×b; M2: b×c
Now I would like to replace the sum by a max:
M3[i,k] = max_j(M1[i,j] * M2[j,k])
As you can see, it is completely parallel to the above; we just take the max over all j instead of the sum.
NOTE: M1 and M2 can be extremely large (e.g. M1 of size 128×10000, or in this form (2×64)×(100×100)). In that case a for-loop is not acceptable because of running time. Because of memory, we also can NOT do (M1.unsqueeze(2)*M2.unsqueeze(0)).max(dim=1), the method mentioned in the ref link.
Ref link: https://stackoverflow.com/questions/41164305/numpy-dot-product-with-max-instead-of-sum
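For reference, this operation is matrix multiplication over the (max, ×) semiring (the max-plus/tropical analogue with products instead of sums). One workaround that avoids both a per-j Python loop and the full a×b×c intermediate is to chunk the broadcasted product over the rows of M1; a minimal sketch (the function name max_mm and the chunk_size default are my own choices, not from the thread):

```python
import torch

def max_mm(M1, M2, chunk_size=256):
    """Max-product 'matmul': out[i, k] = max_j M1[i, j] * M2[j, k].

    Chunks over the rows of M1 so the broadcasted intermediate has
    shape (chunk, b, c) instead of (a, b, c).
    """
    a, b = M1.shape
    b2, c = M2.shape
    assert b == b2, "inner dimensions must match"
    out = M1.new_empty(a, c)
    for start in range(0, a, chunk_size):
        end = min(start + chunk_size, a)
        # (chunk, b, 1) * (1, b, c) -> (chunk, b, c), then reduce over j
        prod = M1[start:end].unsqueeze(2) * M2.unsqueeze(0)
        out[start:end] = prod.max(dim=1).values
    return out
```

chunk_size trades peak memory against the number of loop iterations; with a few hundred rows per chunk the loop overhead is negligible compared to the broadcasted multiply.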

Thanks in advance!

There are many examples of "matmul in CUDA" (e.g. from NVIDIA, but also others), so you could write your own kernel. To get the derivative, I'd recommend writing a function that keeps the indices of the maxima (similar to torch.max with a dimension argument) and uses the corresponding lookup in the backward (you might be able to use "native" PyTorch for that).

Best regards

Thomas
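The idea of keeping the argmax indices for the backward could be sketched with a custom torch.autograd.Function like the one below. This is an illustration only: the forward still materializes the full a×b×c product, so for the sizes in the question it would need to be combined with chunking or a custom kernel. Gradients flow only through the argmax element of each (i, k) pair, as with torch.max.

```python
import torch

class MaxMM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, M1, M2):
        prod = M1.unsqueeze(2) * M2.unsqueeze(0)   # (a, b, c) -- memory-heavy
        vals, idx = prod.max(dim=1)                # idx[i, k] = argmax_j
        ctx.save_for_backward(M1, M2, idx)
        return vals

    @staticmethod
    def backward(ctx, grad_out):
        M1, M2, idx = ctx.saved_tensors
        a, b = M1.shape
        _, c = M2.shape
        i_idx = torch.arange(a, device=M1.device).unsqueeze(1).expand(a, c)
        k_idx = torch.arange(c, device=M1.device).unsqueeze(0).expand(a, c)
        grad_M1 = torch.zeros_like(M1)
        grad_M2 = torch.zeros_like(M2)
        # d out[i,k] / d M1[i,j*] = M2[j*,k]; several k may share one j*,
        # so we scatter with accumulate=True
        grad_M1.index_put_((i_idx, idx), grad_out * M2[idx, k_idx],
                           accumulate=True)
        # d out[i,k] / d M2[j*,k] = M1[i,j*]
        grad_M2.index_put_((idx, k_idx), grad_out * M1[i_idx, idx],
                           accumulate=True)
        return grad_M1, grad_M2
```

Use it as out = MaxMM.apply(M1, M2); the backward then only needs the saved index tensor rather than the full product.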

@yclin

Did you find a solution to this in PyTorch?