Gradient scaling, reversal

I wonder what the best way is to implement gradient reversal, or more generally gradient scaling (reversal is then the special case of a factor of -1).
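To make concrete what I mean, here is a minimal sketch, assuming PyTorch and a plain Python float as the factor. The names `GradScale` and `grad_reverse` are my own and are not taken from any of the libraries mentioned; this is only one possible way to write it.

```python
import torch


class GradScale(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by `scale` in the backward pass."""

    @staticmethod
    def forward(ctx, x, scale):
        # `scale` is assumed to be a plain Python float, so it is kept as an attribute on ctx.
        ctx.scale = scale
        # Return a view of the input rather than the input object itself.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward() argument; `scale` gets no gradient.
        return grad_output * ctx.scale, None


def grad_reverse(x):
    # Gradient reversal is gradient scaling with factor -1.
    return GradScale.apply(x, -1.0)


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = grad_reverse(x)
y.sum().backward()
```

After this runs, `y` has the same values as `x`, while `x.grad` holds the negated upstream gradient.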


Existing implementations:

Some questions on this code:

Fairseq just does ctx.scale = scale, while the other implementations use ctx.save_for_backward(input_, alpha_). What’s the difference? What is better?

Fairseq uses res = but the others do not. Why is this needed? What does it actually do? I did not find the documentation on

The other implementations check for ctx.needs_input_grad[0] in the backward pass but Fairseq does not. Is this not needed?
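For comparison, this is roughly the pattern I mean by "the other implementations", written as a hedged sketch rather than a copy of any of them. It assumes the factor `alpha` is passed as a tensor (since `ctx.save_for_backward` is meant for tensors), and the class name `GradScaleSaved` is hypothetical.

```python
import torch


class GradScaleSaved(torch.autograd.Function):
    """Same idea as gradient scaling, but with a tensor factor saved for the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Assumption: alpha is a tensor, so it can go through save_for_backward.
        ctx.save_for_backward(alpha)
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        (alpha,) = ctx.saved_tensors
        grad_input = None
        # Only compute the gradient for x if autograd actually asks for it.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output * alpha
        # No gradient is returned for alpha in this sketch.
        return grad_input, None


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
alpha = torch.tensor(-0.5)
y = GradScaleSaved.apply(x, alpha)
y.sum().backward()
```

In this small case the `needs_input_grad` check only skips one multiplication, so my question is whether it matters in practice.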