# How to implement the MaxSim operator using torch operations?

Let `T` and `L` be two batches of `(MxN)` matrices, and let `f(ti, lj)` be a function that computes a score for matrices `ti` and `lj`. For instance, if

``````
import torch

T, L = torch.rand(4, 3, 2), torch.rand(4, 3, 2)
# T = tensor([[[0.0017, 0.5781],
#          [0.8136, 0.5971],
#          [0.7697, 0.0795]],

#         [[0.2794, 0.7285],
#          [0.1528, 0.8503],
#          [0.9714, 0.1060]],

#         [[0.6907, 0.8831],
#          [0.4691, 0.4254],
#          [0.2539, 0.7538]],

#         [[0.3717, 0.2229],
#          [0.6134, 0.4810],
#          [0.7595, 0.9449]]])
``````

and the score function is defined as shown in the following code snippet:

``````
def score(ti, lj):
    """MaxSim score of matrices ti and lj: the sum, over the rows
    of ti, of the maximum dot product with any row of lj.
    """
    m = torch.matmul(ti, torch.transpose(lj, 0, 1))
    return m.max(dim=1).values.sum()
``````

How can I compute a score matrix `S`, where `S[i,j]` represents the score between `T[i]` and `L[j]`?

``````
# S = tensor([[2.3405, 2.2594, 2.0989, 1.6450],
#             [2.5939, 2.4186, 2.3946, 2.0648],
#             [2.9447, 2.3652, 2.3829, 2.1536],
#             [2.8195, 2.3105, 2.2563, 1.8388]])
``````

NOTE: This operation must be differentiable.
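For reference, a naive double-loop version produces an `S` of the right shape, but I'd like something vectorized. This sketch assumes `score` is the usual MaxSim, i.e. the sum over rows of `ti` of the max similarity with any row of `lj`:

``````python
# Naive reference implementation: score every (T[i], L[j]) pair in a
# double loop. Slow, but it defines what S should contain.
import torch

def score(ti, lj):
    # Pairwise dot products, then max over lj's rows, summed over ti's rows.
    m = torch.matmul(ti, torch.transpose(lj, 0, 1))
    return m.max(dim=1).values.sum()

T, L = torch.rand(4, 3, 2), torch.rand(4, 3, 2)
S = torch.stack([torch.stack([score(ti, lj) for lj in L]) for ti in T])
print(S.shape)  # torch.Size([4, 4])
``````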

I think you want to do batch matrix multiplication first of all, no? That would be more efficient; see the code below. The max is usually approximated in a differentiable way by the softmax, so maybe you can apply the softmax as shown. What is the objective, etc.?

``````
import torch
import torch.nn as nn

# Using the T and L from the question: each is B x M x N.
T, L = torch.rand(4, 3, 2), torch.rand(4, 3, 2)

# Get something of dimension B x M x M, where B is the batch size.
m = torch.bmm(T, torch.transpose(L, 1, 2))

# For each M x M matrix, the row sums are 1.
probs = nn.Softmax(dim=2)(m)

# Then, you'd want to do a cross-entropy loss across the M classes,
# where the true label is the (actual) max?
``````
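If you do want the full pairwise `S[i, j]` from your question rather than only aligned pairs in the batch, one broadcasted sketch (my assumption of what you're after) keeps the hard max, which autograd handles via a subgradient:

``````python
# Sketch: full pairwise S via einsum broadcasting.
import torch

T, L = torch.rand(4, 3, 2), torch.rand(4, 3, 2)

# sim[i, j, a, b] = dot(T[i, a], L[j, b])  ->  shape (4, 4, 3, 3)
sim = torch.einsum('iak,jbk->ijab', T, L)

# Hard max over L's rows, summed over T's rows; gradients flow
# through the selected max entries.
S = sim.amax(dim=-1).sum(dim=-1)
print(S.shape)  # torch.Size([4, 4])
``````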

BTW - ReLU technically has no derivative at 0, but we can still use it because it has a subgradient.

The max of two things can be expressed as

`f(x, y) = (x + y + |x - y|) / 2`

and I'm sure you can express the max of K things in the same way, basically using ReLUs and absolute values. These can be used since they have subgradients.
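A quick numeric check of that identity, plus the equivalent ReLU form `max(x, y) = y + relu(x - y)`:

``````python
import torch

x, y = torch.tensor(0.5), torch.tensor(-1.25)

via_abs = (x + y + torch.abs(x - y)) / 2   # (x + y + |x - y|) / 2
via_relu = y + torch.relu(x - y)           # same max, written with a ReLU

print(via_abs.item(), via_relu.item())  # 0.5 0.5
``````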