Modifying backward for all modules in the model without rewriting it for all

Let's say I want to modify the computed gradients at each node in the computational graph before they are passed to the previous node; for example, I want to square all computed gradients before passing them further back. Yes, I want to mess with the chain rule and backpropagation. I'd like to know how I can do this without having to write a new backward for every element in my model. What would be the shortest path to accomplish this?

Hi @aynmeme, you might find backward hooks useful. Docs: Module — PyTorch 1.10 documentation
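
As a minimal sketch of that idea: `register_full_backward_hook` lets a hook return a replacement for `grad_input`, which is then what flows to the previous node. Registering the same hook on every leaf submodule applies the change across the whole model without touching any module's backward. The model below is just a made-up example; the hook itself works for any `nn.Module`.

```python
import torch
import torch.nn as nn

# Hypothetical example model; the hook approach applies to any nn.Module.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

def square_grad_input(module, grad_input, grad_output):
    # grad_input holds the gradients about to flow to the previous node(s).
    # Returning a new tuple replaces them; None entries must stay None.
    return tuple(g * g if g is not None else None for g in grad_input)

# Register the hook on every leaf submodule so it applies throughout the graph.
handles = [m.register_full_backward_hook(square_grad_input)
           for m in model.modules() if len(list(m.children())) == 0]

x = torch.randn(2, 4)
model(x).sum().backward()  # gradients passed between modules get squared

# Remove the hooks once they are no longer needed.
for h in handles:
    h.remove()
```

Note that full backward hooks have some restrictions (e.g. they don't play well with in-place operations inside the hooked module), so whether this is enough depends on what your modules do.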