Assume an equation applies a non-differentiable operation to an
nn.Parameter, so the parameter's gradient needs to be estimated using a straight-through estimator (STE) before it can be updated. For example,
y = 0.5(|x| - |x - alpha| + alpha)
y_q = round(y * (2 ** k - 1) / alpha) * alpha / (2 ** k - 1).
In this equation,
alpha is a trainable parameter and the derivative of
alpha needs to be estimated using an STE.
How can I define the approximated gradient and use it in an
optimizer to update the parameter along with the rest of parameters in the network?
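For concreteness, here is one way to write the two equations above in PyTorch (the function name pact_forward is mine). This is just the forward pass; torch.round is piecewise constant, so its gradient is zero almost everywhere, which is exactly why an STE is needed:

```python
import torch

def pact_forward(x, alpha, k):
    # y = 0.5 * (|x| - |x - alpha| + alpha): clips x into [0, alpha]
    y = 0.5 * (x.abs() - (x - alpha).abs() + alpha)
    # uniform k-bit quantization of y on [0, alpha];
    # round() has zero gradient a.e., hence the need for an STE
    scale = (2 ** k - 1) / alpha
    y_q = torch.round(y * scale) / scale
    return y, y_q
```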
Note that it should likely be
I like the trick of
y_q_diffable = y + (y_q - y).detach(). (In fact, I once proposed to give a lightning talk just on this line of code and where it is useful.)
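A minimal, self-contained check of what that line buys you, using the clipping/quantization equations from the question (variable names and toy numbers are mine): the forward value is that of y_q, while alpha still receives the gradient of y.

```python
import torch

alpha = torch.tensor(1.0, requires_grad=True)
x = torch.tensor([0.3, 2.0])
k = 2

y = 0.5 * (x.abs() - (x - alpha).abs() + alpha)  # clip to [0, alpha]
scale = (2 ** k - 1) / alpha
y_q = torch.round(y * scale) / scale             # k-bit quantization

# forward: value of y_q; backward: gradients of y
y_q_diffable = y + (y_q - y).detach()

y_q_diffable.sum().backward()
# alpha.grad now holds d(sum(y))/d(alpha)
```

Because alpha.grad is populated like for any other parameter, a plain torch.optim optimizer (SGD, Adam, ...) that has alpha in its parameter list updates it along with the rest of the network.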
I always like to credit
@hughperkins for sharing the trick here on the forums when he saw it in a paper, and he knows a ton of references for applications, too.
Yes, you’re right. It should be
Do you mind explaining this trick a little bit more in detail?
y_q_diffable is y_q ( = y + y_q - y) in the forward pass. But during the backward pass, the gradients propagate as if y_q_diffable were y.
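A toy illustration of that statement (numbers mine):

```python
import torch

y = torch.tensor([0.2, 0.8], requires_grad=True)
y_q = torch.round(y * 3) / 3   # stand-in for any non-differentiable mapping of y

y_q_diffable = y + (y_q - y).detach()

# forward: y_q_diffable holds the values of y_q (1/3 and 2/3 here)
y_q_diffable.sum().backward()
# backward: y.grad is all ones, exactly as if we had summed y itself
```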
y = tf.stop_gradient(y_hard - y) + y
Whoa, that’s so clever. I had to stare at that for ages before finally figuring it out. So cool. So, the result of this is:
y is pure one-hot in terms of value (since we add the soft y and then subtract it again),
and the gradients are those of the soft y (since all the other terms in this expression have their gradients stripped).
@tom I am curious, did you prepare that talk in the end? Would you mind sharing some references for applications?
No, I didn’t, but my favourite application, which I use in my autograd course, is to emulate quantization-aware training with it. The course is not freely available, but the particular example is also included in the
ACDL “Advanced introduction to PyTorch” talk, of which I published the slides. I don’t know of any video recording, and there wasn’t enough interest to re-record it back then.
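For readers curious about that quantization-aware-training emulation, a sketch of the idea (my own minimal module, not code from the talk; it assumes activations in [0, 1]):

```python
import torch
import torch.nn as nn

class FakeQuant(nn.Module):
    """Simulates k-bit quantization in the forward pass while letting
    gradients pass straight through (the detach trick from above)."""
    def __init__(self, k=8):
        super().__init__()
        self.levels = 2 ** k - 1

    def forward(self, x):
        x_q = torch.round(x * self.levels) / self.levels
        return x + (x_q - x).detach()

# drop it into a model: the forward pass sees quantized activations,
# while the backward pass trains the float weights as usual
model = nn.Sequential(nn.Linear(4, 4), nn.Sigmoid(), FakeQuant(k=4))
model(torch.randn(2, 4)).sum().backward()
```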