Is taking the index of a maximum element and then using that element in further operations a differentiable process?

Hi. I have a model where I need to use one element from a matrix of elements. I computed a softmax probability distribution over this matrix. Now I would like to extract the element that corresponds to the maximum probability from the softmax, and then use that element further on:

_, idx = torch.max(softmax.flatten(), dim=0)   # index of the highest probability
desired_element = matrix.flatten()[idx]        # the matching element of matrix

and then the desired_element is used for further operations in the model.

Is that operation differentiable?



Gradients will flow back from desired_element to matrix, because taking one element out of a matrix is differentiable with respect to the matrix: the selected entry receives the incoming gradient and every other entry receives 0.
Gradients won’t flow back towards the softmax, though, because indexing is not differentiable with respect to the index, and the argmax operation is not differentiable either.
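A minimal sketch of both points, assuming the softmax is taken over the flattened matrix; `scores` and `matrix` are hypothetical random stand-ins for the tensors in the question. After `backward()`, `matrix.grad` is one-hot at the selected position, while the index is an integer tensor outside the graph, so `scores` never receives a gradient.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for the tensors in the question.
scores = torch.randn(3, 4, requires_grad=True)   # logits fed to the softmax
matrix = torch.randn(3, 4, requires_grad=True)   # the matrix we pick from

probs = torch.softmax(scores.flatten(), dim=0)
_, idx = torch.max(probs, dim=0)                 # idx is an int64 tensor
desired_element = matrix.flatten()[idx]

desired_element.backward()

# Indexing is differentiable w.r.t. matrix: a one-hot gradient at idx.
print(matrix.grad)
# The index carries no gradient, so nothing reaches scores.
print(idx.requires_grad, scores.grad)   # False None
```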


Thank you @albanD for your reply. So if I had some learnable weights (W) before the softmax operation, as follows:

_, idx = torch.max(torch.softmax((A @ W @ B).flatten(), dim=0), dim=0)
desired_element = matrix.flatten()[idx]

Then gradients won’t flow through W?

No, they won’t.
You can think about the gradients for idx as follows: how would a small change in idx change the output value? Well, idx is an index, so it is discrete: there is no such thing as a small change in it. So gradients don’t make sense here :confused:
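A small sketch of why, with hypothetical random shapes for `A`, `W`, `B` and `matrix` (the question does not give them). `W.grad` stays `None` after `backward()` because `desired_element` has no differentiable path back to `W`, and nudging `W` slightly leaves the chosen index unchanged: the argmax is piecewise constant, so its derivative is zero almost everywhere and undefined where entries tie.

```python
import torch

torch.manual_seed(0)

# Hypothetical shapes; A, W, B mirror the A.W.B product in the question.
A = torch.randn(3, 5)
W = torch.randn(5, 5, requires_grad=True)
B = torch.randn(5, 4)
matrix = torch.randn(3, 4, requires_grad=True)

def pick(W):
    """Softmax over the flattened product, then select by the argmax index."""
    probs = torch.softmax((A @ W @ B).flatten(), dim=0)
    _, idx = torch.max(probs, dim=0)
    return idx, matrix.flatten()[idx]

idx, desired_element = pick(W)
desired_element.backward()
print(W.grad)  # None: desired_element does not depend on W through autograd

# Finite-difference view: a tiny nudge to W does not move the argmax.
with torch.no_grad():
    idx_nudged, _ = pick(W + 1e-6 * torch.randn_like(W))
print(idx.item(), idx_nudged.item())
```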


Thank you very much @albanD for your answer. What if I take a weighted sum of all the elements in matrix (weighted by the softmax), but first multiply those weights by a mask generated from idx, as follows:

weights = torch.softmax((A @ W @ B).flatten(), dim=0)
_, idx = torch.max(weights, dim=0)
mask = torch.ones_like(weights)
mask[weights < idx] = 0
weighted_sum = torch.sum(weights * mask * matrix.flatten())

This would achieve the same purpose as before, because the ignored elements are multiplied by 0. So would that be differentiable?

Doing this, you will get gradients back to your W, but it is not exactly the same function, because you now multiply the selected entry by its weight. I’m not sure how stable this function will be during training, though.
Also in your example, I think you want the maximum value out of the max op, not the index of the maximum value.
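A minimal sketch of the masked version with the mask built from the maximum value rather than the index, again with hypothetical random shapes. The comparison produces a boolean tensor, so the mask is a constant for autograd; the gradient reaches `W` only through `weights`.

```python
import torch

torch.manual_seed(0)

# Hypothetical shapes, as in the question's A.W.B.
A = torch.randn(3, 5)
W = torch.randn(5, 5, requires_grad=True)
B = torch.randn(5, 4)
matrix = torch.randn(3, 4)

weights = torch.softmax((A @ W @ B).flatten(), dim=0)
value, idx = torch.max(weights, dim=0)

# Compare against the maximum *value*, not the index.
# A comparison yields a bool tensor, so the mask carries no gradient.
mask = (weights >= value).to(weights.dtype)
weighted_sum = torch.sum(weights * mask * matrix.flatten())

weighted_sum.backward()
print(mask.requires_grad)   # False: no gradient flows through the mask
print(W.grad is not None)   # True: the gradient reaches W through `weights`
```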

Hi @albanD, and thanks for the fast answers! I actually just want to take the maximum element from matrix and ignore everything else. But I still need the gradients to flow through W, because I need it to learn how to choose the maximum.

outputs = torch.softmax((A @ W @ B).flatten(), dim=0)
value, idx = torch.max(outputs, dim=0)
mask = torch.ones_like(outputs)
mask[outputs < value] = 0   # keep only the maximum (mask is already 1 elsewhere)
outputs = outputs * mask
weighted_sum = torch.sum(outputs * matrix.flatten())

This keeps only the maximum element. I guess this would be fine? We are just changing the output tensor of a specific layer.
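A quick check of what this masked sum actually computes, with hypothetical random shapes. It equals the selected element scaled by its probability `value`, not the selected element itself, which is the "not exactly the same function" point above; correspondingly, the gradient with respect to `matrix` is `value` at `idx` and 0 elsewhere.

```python
import torch

torch.manual_seed(0)

# Hypothetical shapes; matrix also requires grad so its gradient can be inspected.
A = torch.randn(3, 5)
W = torch.randn(5, 5, requires_grad=True)
B = torch.randn(5, 4)
matrix = torch.randn(3, 4, requires_grad=True)

outputs = torch.softmax((A @ W @ B).flatten(), dim=0)
value, idx = torch.max(outputs, dim=0)
mask = torch.ones_like(outputs)
mask[outputs < value] = 0
weighted_sum = torch.sum(outputs * mask * matrix.flatten())

# The masked sum is the selected element scaled by its probability.
picked = matrix.flatten()[idx]
print(torch.allclose(weighted_sum, value * picked))

weighted_sum.backward()
# d(weighted_sum)/d(matrix) is `value` at idx and 0 elsewhere.
print(matrix.grad.flatten()[idx].item(), value.item())
```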

Small caveat: argmax is non-differentiable only with respect to the index.


But argmax only returns the index :wink: Only max returns the value.
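To illustrate the difference on a tiny hand-picked tensor: `torch.max` with `dim` returns both the values (which stay in the autograd graph) and the indices (an integer tensor with no gradient), while `torch.argmax` returns only the integer index.

```python
import torch

x = torch.tensor([0.1, 0.7, 0.2], requires_grad=True)

values, indices = torch.max(x, dim=0)   # named tuple (values, indices)
am = torch.argmax(x)                    # index only

print(values.grad_fn is not None)             # True: the value is in the graph
print(indices.dtype, indices.requires_grad)   # torch.int64 False
print(am.dtype, am.requires_grad)             # torch.int64 False

values.backward()
print(x.grad)   # tensor([0., 1., 0.]): gradient routed to the max position
```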


Meaningful explanation!!!