Inherit from autograd.Function

This is not directly relevant to the issue you’re seeing, but it’s important to note:
Certain parts of torch.autograd either currently do, or will in the future, assume that the gradients returned by Functions are correct (i.e., equal to the mathematical derivative). If you want to do things that violate this assumption, that's fine, but they should be implemented as gradient hooks (var.register_hook), which can arbitrarily modify gradient values, not as Functions.
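
For example, here's a minimal sketch of the hook approach (the tensor name and the doubling factor are just illustrative): a hook registered with `register_hook` can return a new tensor, which replaces the gradient flowing into that tensor during the backward pass.

```python
import torch

x = torch.randn(3, requires_grad=True)

# Register a hook that arbitrarily rescales the gradient w.r.t. x.
# Returning a tensor from the hook replaces the original gradient.
x.register_hook(lambda grad: grad * 2.0)

y = (x ** 2).sum()
y.backward()

# The true derivative of sum(x**2) is 2*x; the hook doubles it to 4*x.
print(torch.allclose(x.grad, 4 * x.detach()))  # True
```

This way every Function's backward stays mathematically correct, and the arbitrary modification happens where autograd expects it to.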
