Gumbel-Softmax-like function

Hello everyone,

I’m unsure if it’s appropriate to ask a mathematical question here, but I don’t have anywhere else to turn.

Let me explain my goal. Consider an input tensor like [1,2,3,4,5]. I want to obtain a one-hot encoding of the argmax of this tensor, which for this example would be [0,0,0,0,1].
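To make this concrete, here is the operation I want as a minimal PyTorch sketch (the non-differentiable version is easy to write):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)

# The operation I want, written non-differentiably:
# argmax breaks gradient flow, so this can't be trained through.
hard = F.one_hot(x.argmax(), num_classes=x.numel()).float()
print(hard)  # tensor([0., 0., 0., 0., 1.])
```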

The issue is that this process needs to be differentiable. In my research, I’ve looked into the Gumbel-Softmax and softargmax functions. From my understanding, Gumbel-Softmax draws a stochastic sample from the categorical distribution defined by the input, so it does not guarantee that the output is the one-hot vector of the argmax, [0,0,0,0,1].
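For example, with PyTorch's `F.gumbel_softmax` (the temperature `tau=1.0` below is just an illustrative choice), the sampled index varies across calls rather than always being the argmax:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])

# Gumbel-Softmax draws a stochastic sample: index 4 is the most
# likely outcome, but not a guaranteed one.
for _ in range(3):
    y = F.gumbel_softmax(logits, tau=1.0, hard=True)
    print(y)
```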

With softargmax, I obtained a good approximation of the argmax index. However, after one-hot encoding that index, I ended up with NaN loss values. Is one-hot encoding non-differentiable, or did I make a mistake somewhere?
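Here is roughly what I tried, reconstructed as a sketch (the `beta=100.0` value and the round-then-one-hot step are illustrative of my approach, not my exact code):

```python
import torch
import torch.nn.functional as F

def softargmax(x, beta=100.0):
    # Differentiable approximation of argmax: a weighted average of
    # the index positions, weighted by softmax(beta * x).
    idx = torch.arange(x.numel(), dtype=x.dtype)
    return (F.softmax(beta * x, dim=-1) * idx).sum()

x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
soft_idx = softargmax(x)  # ~4.0, still differentiable
hard = F.one_hot(soft_idx.round().long(), num_classes=x.numel())
# The round()/long() cast detaches the result from the graph:
# `hard` has no grad_fn, so no gradient flows through this step.
print(soft_idx, hard, hard.requires_grad)
```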

To summarize:

  1. Are there any alternatives to Gumbel-Softmax for achieving a differentiable one-hot(argmax(input_tensor))?
  2. Is one-hot encoding non-differentiable?
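For reference, the closest pattern I’ve come across in my research is the straight-through trick, where the forward pass uses the exact one-hot vector but the backward pass uses the softmax gradients. The sketch below is my own understanding of that pattern, and I’m not sure it’s correct for my case:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)

soft = F.softmax(x, dim=-1)
hard = F.one_hot(soft.argmax(dim=-1), num_classes=x.numel()).float()
# Straight-through: forward pass sees the exact one-hot vector,
# backward pass uses the softmax gradients. This is essentially
# what gumbel_softmax(hard=True) does, minus the Gumbel noise.
y = hard - soft.detach() + soft
print(y)  # tensor([0., 0., 0., 0., 1.], grad_fn=...)

target = torch.tensor([0.0, 0.0, 1.0, 0.0, 0.0])
loss = ((y - target) ** 2).sum()
loss.backward()
print(x.grad)  # nonzero: gradients reach x via the soft path
```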