Making a further prediction with argmax in an end2end model


I want to first make a prediction with r_index = x.argmax(). The result, r_index, is an index into a set of predefined rates. For example, with 3 predefined rates: rates = torch.Tensor([0.1, 0.5, 0.7]).
I then use this index for a further prediction: output = rates[r_index] * ...
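Here is a minimal sketch of the forward pass described above. It assumes a 1-D score tensor x. The values of x and the second factor `other`, which stands in for the elided `...`, are hypothetical. It shows that the result no longer carries a gradient back to x:

```python
import torch

# Predefined rates from the question.
rates = torch.tensor([0.1, 0.5, 0.7])

# Hypothetical scores from an upstream model; requires_grad mimics a trainable output.
x = torch.tensor([0.2, 1.5, -0.3], requires_grad=True)

# Hypothetical second factor standing in for the elided "..." in the question.
other = torch.tensor(2.0)

r_index = x.argmax()             # integer (int64) index, carries no gradient
output = rates[r_index] * other  # indexing with it selects a rate but detaches from x

print(r_index.item())            # 1, i.e. rate 0.5
print(output.requires_grad)      # False: the autograd graph is cut at argmax
```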

Since argmax is not differentiable, how can I make this work in PyTorch? I want to do it in an end-to-end model.
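One common workaround, which does not come from this thread, is the straight-through estimator. It uses the hard argmax choice in the forward pass but lets gradients flow through softmax(x) in the backward pass. Below is a minimal sketch that assumes x holds one score per rate along its last dimension. The helper name `straight_through_rate` is hypothetical:

```python
import torch
import torch.nn.functional as F

rates = torch.tensor([0.1, 0.5, 0.7])

def straight_through_rate(x: torch.Tensor, rates: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: returns rates[x.argmax(-1)] in the forward pass,
    while gradients are computed as if the weights were softmax(x)."""
    soft = F.softmax(x, dim=-1)
    hard = F.one_hot(x.argmax(dim=-1), num_classes=x.shape[-1]).to(soft.dtype)
    # Forward value equals `hard` (up to float rounding); gradient is that of `soft`.
    weights = hard + soft - soft.detach()
    return (weights * rates).sum(dim=-1)

x = torch.tensor([0.2, 1.5, -0.3], requires_grad=True)
rate = straight_through_rate(x, rates)
rate.backward()

print(rate.item())  # approximately 0.5, the same rate that rates[x.argmax()] picks
print(x.grad)       # non-zero: gradients now reach x
```

The gradient it produces is a surrogate (the softmax gradient), not the true gradient of argmax, which is zero almost everywhere. PyTorch's `torch.nn.functional.gumbel_softmax(logits, hard=True)` implements a related, sampled variant of the same trick.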

Thanks a lot if anyone can help!

Hi Steve,

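Could you brute force this by training a small model of the type:
Linear(len(x), len_intermediary)
ReLU()
Linear(len_intermediary, 1)
(A nonlinearity such as ReLU between the two layers is needed. Without it, the layers collapse into a single linear map, which cannot reproduce the step-like argmax-to-rate mapping.)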

And train it to output the desired result (e.g. 0.1 if argmax(x) == 0, 0.5 if argmax(x) == 1, and so forth). Then simply freeze these parameters and load them into your end2end model in the appropriate place.
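A minimal sketch of that pre-training step follows. It assumes x has one score per rate and that standard-normal random inputs resemble what the end-to-end model will produce. The hidden width, optimizer, learning rate and step count are illustrative guesses, not tuned values:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

rates = torch.tensor([0.1, 0.5, 0.7])
n_in = len(rates)        # len(x) in the post; assumed equal to the number of rates
len_intermediary = 32    # hypothetical hidden width

# Linear -> ReLU -> Linear; the ReLU lets the net approximate the
# piecewise-constant argmax-to-rate mapping.
lookup = nn.Sequential(
    nn.Linear(n_in, len_intermediary),
    nn.ReLU(),
    nn.Linear(len_intermediary, 1),
)

opt = torch.optim.Adam(lookup.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

for step in range(2000):
    # Random inputs; assumption: they roughly match the range of the real x.
    x = torch.randn(256, n_in)
    target = rates[x.argmax(dim=-1)].unsqueeze(-1)  # desired rate, shape (256, 1)
    loss = loss_fn(lookup(x), target)
    opt.zero_grad()
    loss.backward()
    opt.step()

# Freeze the learned mapping before placing it in the end-to-end model.
for p in lookup.parameters():
    p.requires_grad_(False)
lookup.eval()

print(f"final training MSE: {loss.item():.4f}")
```

Because the net is continuous, it can only approximate the hard mapping near ties, where two scores are almost equal. How sharp that transition gets depends on width and training.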

In other words, perhaps the transformation you want (from the argmax index to a specific rate) can be learned by a few fully connected layers. That would let you piggyback on the differentiable tools already available.
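This sketch shows how the frozen net could sit inside the end-to-end model. It assumes a hypothetical upstream `scorer` layer that produces x. The lookup net is left untrained here only to show that gradients pass through it to the upstream layer while its own weights stay fixed:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for the trained, frozen lookup net (untrained here, for illustration only).
lookup = nn.Sequential(nn.Linear(3, 32), nn.ReLU(), nn.Linear(32, 1))
for p in lookup.parameters():
    p.requires_grad_(False)

class EndToEnd(nn.Module):
    """Hypothetical end-to-end model: an upstream layer produces the scores x,
    and the frozen lookup net turns them into a differentiable rate."""

    def __init__(self, lookup: nn.Module):
        super().__init__()
        self.scorer = nn.Linear(8, 3)  # hypothetical upstream layer producing x
        self.lookup = lookup

    def forward(self, features: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
        x = self.scorer(features)
        rate = self.lookup(x).squeeze(-1)  # replaces rates[x.argmax()]
        return rate * other                # mirrors output = rates[r_index] * ...

model = EndToEnd(lookup)
features = torch.randn(4, 8)
other = torch.randn(4)
model(features, other).sum().backward()

print(model.scorer.weight.grad is not None)                    # True: upstream layer gets gradients
print(all(p.grad is None for p in model.lookup.parameters()))  # True: frozen weights get none
```

The frozen weights are still registered in `model.parameters()`, so one option is to build the optimizer from `[p for p in model.parameters() if p.requires_grad]`.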

You probably want to make sure that the range of x used in this separate training matches what will come out of your end2end model. You may need to apply a normalization function such as sigmoid to x to bring everything into a comparable range.
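Here is a small illustration of the sigmoid suggestion, using hypothetical score values. Sigmoid is strictly increasing, so squashing x into (0, 1) does not change which index is largest, and the training targets stay the same:

```python
import torch

# Hypothetical raw scores, one row per example.
x = torch.tensor([
    [0.3, -1.2, 2.0],
    [4.0, 1.0, -3.0],
    [-0.5, -0.1, -2.0],
])

# Sigmoid maps each score into (0, 1), giving the helper net a fixed input range.
x_norm = torch.sigmoid(x)

# Strict monotonicity means the argmax, and hence the target rate, is unchanged.
print(x.argmax(dim=-1).tolist())                             # [2, 0, 1]
print(torch.equal(x.argmax(dim=-1), x_norm.argmax(dim=-1)))  # True
```

One caveat: for large positive scores, sigmoid rounds to exactly 1.0 in float32, so distinct large scores can become ties. Scaling x down first avoids that.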

Curious if that works,