I am trying out a policy network with the Gumbel-softmax provided by PyTorch.

r_out = myRNNnetwork(x, h, c)

Policy = F.gumbel_softmax(r_out, tau=temperature, hard=True)

In the above implementation, r_out is the output of the RNN, i.e. the values (logits) before sampling. It is a 1x2 float tensor such as [-0.674, -0.722], and I noticed that r_out[0] is always larger than r_out[1].

I then sample the policy with gumbel_softmax, and the output is either [0, 1] or [1, 0] depending on the input signal.
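To make the sampling step concrete, here is a minimal sketch that draws many hard Gumbel-softmax samples from the example values above and compares how often each one-hot outcome appears with softmax(r_out). The temperature of 1.0 and the sample count are assumptions for illustration, not values from my actual setup, and the fixed tensor stands in for the real RNN output.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Example values from the post; stands in for the real RNN output r_out.
r_out = torch.tensor([[-0.674, -0.722]])
temperature = 1.0  # assumed value, not stated in the post
n_samples = 10000  # assumed value, chosen only for a stable estimate

# Repeat the same logits so each row gets its own Gumbel noise.
logits = r_out.repeat(n_samples, 1)

# With hard=True every row is one-hot: [1, 0] or [0, 1].
samples = F.gumbel_softmax(logits, tau=temperature, hard=True)

# Fraction of samples that picked each action.
empirical = samples.mean(dim=0)

# Categorical probabilities implied by the logits.
expected = F.softmax(r_out, dim=-1).squeeze(0)

print("empirical frequencies:", empirical)
print("softmax(r_out):       ", expected)
```

With two values this close together, both outcomes show up at roughly equal rates, so a larger r_out[0] alone does not force the sample to be [1, 0].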

Although r_out[0] is always larger than r_out[1], the network does seem to learn something meaningful (i.e. it generates the correct [0, 1] or [1, 0] for a specific input x). This surprised me. So my first question is: is it normal for r_out[0] to always be larger than r_out[1] while the policy is still correct after Gumbel-softmax sampling?

In addition, what is the correct way to perform inference or validation with a model trained like this? Should I still use Gumbel-softmax during inference? My worry is that it will introduce randomness. But if I replace the Gumbel-softmax sampling with a deterministic r_out.argmax(), the result is always [1, 0], which is also not right.
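For reference, here is a minimal sketch of the two inference options I am comparing. The helper names greedy_policy and sampled_policy are hypothetical, the default temperature of 1.0 is an assumption, and the fixed tensor again stands in for the real RNN output.

```python
import torch
import torch.nn.functional as F


def greedy_policy(r_out):
    """Deterministic option (hypothetical helper): one-hot of the argmax."""
    index = r_out.argmax(dim=-1)
    return F.one_hot(index, num_classes=r_out.shape[-1]).to(r_out.dtype)


def sampled_policy(r_out, temperature=1.0):
    """Stochastic option (hypothetical helper): hard Gumbel-softmax sample,
    the same call used during training. The temperature is an assumption."""
    return F.gumbel_softmax(r_out, tau=temperature, hard=True)


# Example values from the post; stands in for the real RNN output.
r_out = torch.tensor([[-0.674, -0.722]])

print(greedy_policy(r_out))   # tensor([[1., 0.]])
print(sampled_policy(r_out))  # one-hot row, varies from call to call
```

The greedy version reproduces the fixed [1, 0] that I see with argmax, while the sampled version can return either one-hot vector for the same input.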

Could someone provide some guidance on this?