Hi,
I want to first make a prediction with r_index = x.argmax(). The argmax() result r_index is an index into a set of predefined rates. For example, with 3 predefined rates: rates = torch.Tensor([0.1, 0.5, 0.7]).
Then I need this index for a further prediction: output = rates[r_index] * ...
Since argmax is not differentiable, how can I make this work in PyTorch? I want to use it in an end-to-end model.
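To make the problem concrete, here is a minimal sketch of the setup. The random x and the factor scale are only placeholders (assumptions for illustration) standing in for the model's scores and for the rest of the computation, which is not shown here.

```python
import torch

rates = torch.tensor([0.1, 0.5, 0.7])

# Stand-in for the scores produced by the model (assumption for illustration).
x = torch.randn(3, requires_grad=True)

# argmax returns an integer index, so it carries no gradient information.
r_index = x.argmax()

# Hypothetical placeholder for the elided part of the computation.
scale = torch.tensor(2.0)
output = rates[r_index] * scale

print(r_index.dtype)         # torch.int64
print(output.requires_grad)  # False: nothing connects output back to x
```

Because output is not connected to x in the autograd graph, calling backward on it cannot update whatever produced x.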
Thanks a lot if anyone can help!
Hi Steve,
Could you brute force this by training a model of the type:
Linear(len(x), len_intermediary)
Linear(len_intermediary, 1)
And train it to output the desired result (e.g. 0.1 if argmax(x) == 0, 0.5 if argmax(x) == 1, and so forth). Then simply freeze these parameters and load them into your end2end model in the appropriate place.
In other words, perhaps the transformation you want, from argmax to a specific value, can be learned by a series of fully connected layers. This would allow you to piggyback on the differentiable tools already available.
You probably want to make sure that the range of x used in this separate training matches what would be coming in from your end2end model. You may need to apply a normalization function such as sigmoid to x so that everything falls in a comparable range.
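A rough sketch of what I have in mind, under several assumptions: x has one score per rate, the inputs are random values squashed by sigmoid, and the hidden width, learning rate, and step count are arbitrary. I also put a ReLU between the two Linear layers, since two stacked Linear layers without a nonlinearity collapse into a single affine map. Treat it as a starting point, not a tested recipe.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

rates = torch.tensor([0.1, 0.5, 0.7])
n_inputs = len(rates)     # assumption: one score in x per rate
len_intermediary = 32     # arbitrary hidden width

# Surrogate for "rates[argmax(x)]". The ReLU is an addition to the
# two Linear layers listed above.
surrogate = nn.Sequential(
    nn.Linear(n_inputs, len_intermediary),
    nn.ReLU(),
    nn.Linear(len_intermediary, 1),
)

optimizer = torch.optim.Adam(surrogate.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

# Separate training: random inputs, normalized with sigmoid so they
# live in the same range the end2end model will feed in.
for step in range(500):
    x = torch.sigmoid(torch.randn(256, n_inputs))
    target = rates[x.argmax(dim=1)].unsqueeze(1)  # e.g. 0.1 if argmax == 0
    loss = loss_fn(surrogate(x), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Freeze the surrogate before plugging it into the end2end model.
for p in surrogate.parameters():
    p.requires_grad_(False)

# Inside the end2end model: gradients still flow through the frozen
# surrogate back to whatever produced the scores.
logits = torch.randn(4, n_inputs, requires_grad=True)  # stand-in for upstream output
rate = surrogate(torch.sigmoid(logits))                # shape (4, 1), approximate rate
rate.sum().backward()
print(logits.grad.shape)  # torch.Size([4, 3])
```

The output is only an approximation of the selected rate, so it is worth checking how close it gets, especially when two entries of x are nearly tied.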
Curious if that works,
Andrei