Stop gradients (for ST gumbel softmax)

From https://gist.github.com/ericjang/1001afd374c2c3b7752545ce6d9ed349#file-gumbel-softmax-py-L27

y = tf.stop_gradient(y_hard - y) + y

Whoa, that’s so clever. I had to stare at that for ages before finally figuring it out. So cool :slight_smile:. So, the result of this is:

  • y is pure one-hot in terms of value (the soft y is subtracted inside the stop_gradient and added back outside, so the value is exactly y_hard)
  • the gradients are those of the soft y (the stop_gradient term contributes no gradient, so only the final + y carries gradient), as the sketch below illustrates
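Here's a minimal sketch (TF 2.x, eager mode) to make that concrete. The logits and the dummy loss below are made up for illustration; the point is just to check that the forward value is hard one-hot while the gradient is the same one you'd get from the soft y:

```python
import tensorflow as tf

# Hypothetical example input, not from the original post.
logits = tf.Variable([[1.0, 2.0, 0.5]])

with tf.GradientTape() as tape:
    y = tf.nn.softmax(logits)                        # soft probabilities
    y_hard = tf.one_hot(tf.argmax(y, axis=-1),
                        depth=y.shape[-1])           # hard one-hot, no gradient path
    y_st = tf.stop_gradient(y_hard - y) + y          # value: y_hard, gradient: that of y
    loss = tf.reduce_sum(y_st * tf.constant([[0.1, 0.2, 0.3]]))  # dummy loss

print(y_st)                          # exactly one-hot in value
print(tape.gradient(loss, logits))   # same gradient as if the loss had used soft y
```

In the forward pass the two y terms cancel, leaving y_hard; in the backward pass the stop_gradient term is a constant, so the gradient flows entirely through the trailing + y.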