From https://gist.github.com/ericjang/1001afd374c2c3b7752545ce6d9ed349#file-gumbel-softmax-py-L27
y = tf.stop_gradient(y_hard - y) + y
Whoa, that’s so clever. I had to stare at it for ages before I finally figured it out. So cool. So, the result of this is:
- y is pure one-hot, in terms of value (the soft y is subtracted inside the stop_gradient and then added back outside, so the soft parts cancel and we're left with y_hard)
- the gradients are those of the soft y (since every other term in the expression has its gradient stripped by stop_gradient); see the sketch below
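To make the trick concrete, here is a minimal sketch of a straight-through Gumbel-Softmax sample putting that line in context. This is a TF2-style rewrite, not the gist's exact code (which is TF1), and the function names `sample_gumbel` and `gumbel_softmax_st` are just illustrative:

```python
import tensorflow as tf

def sample_gumbel(shape, eps=1e-20):
    # Gumbel(0, 1) noise via inverse transform of uniform samples.
    u = tf.random.uniform(shape, minval=0.0, maxval=1.0)
    return -tf.math.log(-tf.math.log(u + eps) + eps)

def gumbel_softmax_st(logits, temperature=1.0):
    # Soft sample: differentiable relaxation of a categorical draw.
    y = tf.nn.softmax((logits + sample_gumbel(tf.shape(logits))) / temperature)
    # Hard sample: one-hot of the argmax of the soft sample.
    y_hard = tf.one_hot(tf.argmax(y, axis=-1),
                        depth=tf.shape(logits)[-1], dtype=y.dtype)
    # The trick: forward value equals y_hard (the soft y cancels out),
    # but gradients flow only through the soft y term outside stop_gradient.
    return tf.stop_gradient(y_hard - y) + y

logits = tf.math.log(tf.constant([[0.1, 0.7, 0.2]]))
sample = gumbel_softmax_st(logits)  # exact one-hot forward, soft gradients backward
```

So the forward pass behaves like a discrete sample, while backprop sees the smooth Gumbel-Softmax relaxation.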