Neither. You should use torch.autograd.Function; see the docs.
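As a rough illustration, here is a minimal custom autograd.Function with a hand-written backward. The op itself (Cube, y = x**3) is a hypothetical stand-in for whatever operation you actually need. The last line uses torch.autograd.gradcheck to compare the hand-written backward against finite differences.

```python
import torch


class Cube(torch.autograd.Function):
    """Hypothetical example op: y = x**3 with a hand-written backward."""

    @staticmethod
    def forward(ctx, x):
        # Save the input, because backward needs it to compute dy/dx = 3 * x**2.
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Chain rule: dL/dx = dL/dy * dy/dx.
        return grad_output * 3 * x ** 2


# Call the op through .apply, not by instantiating the class.
x = torch.randn(4, dtype=torch.double, requires_grad=True)
y = Cube.apply(x)
y.sum().backward()
print(torch.allclose(x.grad, 3 * x.detach() ** 2))

# gradcheck checks backward numerically; double precision inputs are recommended.
print(torch.autograd.gradcheck(Cube.apply, (x,)))
```

Both prints should report True if backward matches the analytical gradient.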
By the way, register_hook may be a better choice if it works for your case. Either way, make sure that backward gives the true gradient.
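Here is a small sketch of Tensor.register_hook on an intermediate tensor. It assumes you only want to observe the gradient, which is also a cheap way to check it against the value you expect. The helper name save_grad and the grads dict are hypothetical. If a hook returns None, the gradient is left unchanged. If it returns a tensor, that tensor replaces the gradient.

```python
import torch

grads = {}  # hypothetical storage for captured gradients


def save_grad(name):
    # Build a hook that stores the incoming gradient. Returning None keeps
    # the gradient unchanged, while returning a tensor would replace it.
    def hook(grad):
        grads[name] = grad.detach().clone()
    return hook


x = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
y = x * 2                              # intermediate (non-leaf) tensor
handle = y.register_hook(save_grad("y"))

loss = (y ** 2).sum()
loss.backward()

print(grads["y"])  # dloss/dy = 2 * y, expected values 4, -8, 12
print(x.grad)      # dloss/dx = 2 * dloss/dy, expected values 8, -16, 24
handle.remove()    # remove the hook once it is no longer needed
```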
For examples of register_hook, and more, see:
https://discuss.pytorch.org/search?q=register%20hook