# PyTorch: Matrix multiplication (@) with max instead of sum

Is there a way in PyTorch to do the following (or is there a general mathematical term for this)?
Assume normal matrix multiplication (torch.mm):
M3[i,k] = sum_j(M1[i,j] * M2[j,k]), with sizes M1: a×b, M2: b×c, M3: a×c
Now I would like to replace the sum by a max:
M3[i,k] = max_j(M1[i,j] * M2[j,k])
As you can see, it is completely parallel to the above; we just take the `max` over all `j` instead of the sum.
NOTE: M1 and M2 can be extremely large (e.g. M1 of size 128×10000, or in this form: (2×64)×(100×100)). In this case a for-loop is not acceptable because of running time. Considering memory, we can NOT do `(M1.unsqueeze(2) * M2.unsqueeze(0)).max(dim=1)` either, a method mentioned in the ref link.
There are many examples of "matmul in CUDA" (e.g. from NVIDIA but also others), so you could write your own kernel. To get the derivative, I'd recommend writing a function that keeps the indices of the maxima (similar to `torch.max` with a dimension argument) and uses the corresponding lookup in the backward pass (you might be able to use "native" PyTorch for that).
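As a rough sketch of the "native" PyTorch route (not a tuned kernel): the forward pass below chunks over `j` so the `(a, chunk, c)` intermediate stays bounded, and it records the argmax indices so the backward pass can route gradients only through the winning entries. The class name `MaxMM` and the `chunk` parameter are my own inventions for illustration.

```python
import torch

class MaxMM(torch.autograd.Function):
    """out[i, k] = max_j M1[i, j] * M2[j, k] -- matmul with max in place of sum."""

    @staticmethod
    def forward(ctx, M1, M2, chunk=1024):
        a, b = M1.shape
        _, c = M2.shape
        vals = torch.full((a, c), float("-inf"), dtype=M1.dtype, device=M1.device)
        idx = torch.zeros((a, c), dtype=torch.long, device=M1.device)
        # Chunk over j so the broadcasted intermediate is (a, chunk, c),
        # not (a, b, c), which would blow up memory for large b.
        for s in range(0, b, chunk):
            prod = M1[:, s:s + chunk].unsqueeze(2) * M2[s:s + chunk].unsqueeze(0)
            v, i = prod.max(dim=1)              # per-chunk max and argmax over j
            mask = v > vals
            vals = torch.where(mask, v, vals)   # running max across chunks
            idx = torch.where(mask, i + s, idx) # argmax in global j coordinates
        ctx.save_for_backward(M1, M2, idx)
        return vals

    @staticmethod
    def backward(ctx, grad_out):
        M1, M2, idx = ctx.saved_tensors
        a, _ = M1.shape
        _, c = M2.shape
        rows = torch.arange(a, device=idx.device).unsqueeze(1).expand(a, c)
        cols = torch.arange(c, device=idx.device).unsqueeze(0).expand(a, c)
        gM1 = torch.zeros_like(M1)
        gM2 = torch.zeros_like(M2)
        # Gradient flows only through the argmax entries, looked up via idx
        # (this is the "corresponding lookup in the backward" mentioned above).
        gM1.index_put_((rows, idx), grad_out * M2[idx, cols], accumulate=True)
        gM2.index_put_((idx, cols), grad_out * M1[rows, idx], accumulate=True)
        return gM1, gM2, None
```

Usage is `MaxMM.apply(M1, M2)`; you can sanity-check it against the broadcast version on small inputs and with `torch.autograd.gradcheck`. This is still a Python-level loop over chunks, so a hand-written CUDA kernel would be faster, but each chunk does a large vectorized operation, so it is far from a per-element for-loop.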