I wonder about the best way to implement gradient reversal, or gradient scaling in general (reversal is just the special case of using factor -1).
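Concretely, what I have in mind is something along these lines (a minimal sketch with names of my own, not taken from any particular library):

```python
import torch

class GradScale(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by `scale` in the backward pass."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # one gradient per forward input; the `scale` argument gets None
        return grad_output * ctx.scale, None


x = torch.randn(3, requires_grad=True)
GradScale.apply(x, -1.0).sum().backward()
print(x.grad)  # tensor([-1., -1., -1.]) -> gradient reversal
```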
Related:
Existing implementations:
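To make the questions below self-contained, the two styles of code I'm comparing look roughly like this (paraphrased from the linked implementations, so details may differ from the actual repos):

```python
import torch

# Fairseq-style: the scale is stored as a plain attribute on ctx
class GradMultiply(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        res = x.new(x)
        return res

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None


# Style of the other implementations: tensors stored via save_for_backward
# (so alpha_ is expected to be a tensor), plus a needs_input_grad check
class GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, alpha_):
        ctx.save_for_backward(input_, alpha_)
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        _, alpha_ = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            grad_input = -grad_output * alpha_
        return grad_input, None
```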
Some questions on this code:
1. Fairseq just does `ctx.scale = scale`, while the other implementations use `ctx.save_for_backward(input_, alpha_)`. What's the difference, and which is better?
2. Fairseq uses `res = x.new(x)` but the others do not. Why is this needed, and what does it actually do? I could not find documentation on `Tensor.new`.
3. The other implementations check `ctx.needs_input_grad[0]` in the backward pass, but Fairseq does not. Is this check not needed?