I was thinking of an optimization for gradient computation where, if a node’s gradient is below a given threshold, we don’t propagate gradients to its children, i.e. we detach that node from the computation graph.
I tested this by creating a tensor v, defining u = v**2, and calling u.sum().backward(). I registered a hook on u that returned None for its gradient, but I found that the gradient was still being computed for v.
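A rough sketch of what I tried (shapes and values are just placeholders):

```python
import torch

v = torch.randn(5, requires_grad=True)
u = v ** 2

# Hook that returns None, hoping it detaches u from the graph
u.register_hook(lambda grad: None)

u.sum().backward()
print(v.grad)  # still equals 2 * v, so backward still ran through v ** 2
```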
Hence I was wondering how to do this. Of course, the naive way would be to register a hook and zero out the gradients when they are below the threshold (see the sketch below); while that would have the same numerical effect, I wouldn’t get the desired speedup, which defeats the purpose.
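Something like this (threshold value is arbitrary):

```python
import torch

THRESHOLD = 1e-3  # hypothetical cutoff

def zero_small_grads(grad):
    # Replace small gradient entries with zeros. Numerically this "detaches"
    # those entries, but autograd still walks the whole graph, so no speedup.
    return torch.where(grad.abs() < THRESHOLD, torch.zeros_like(grad), grad)

v = torch.randn(5, requires_grad=True)
u = v ** 2
u.register_hook(zero_small_grads)
u.sum().backward()
```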
Returning None does not work because we treat it as “don’t change the gradient”. This is because, in Python, any function that doesn’t return anything returns None.
You can return torch.zeros_like(grad) to get a gradient full of zeros.
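To illustrate the difference with the None case above, a minimal sketch on the same toy setup:

```python
import torch

v = torch.randn(5, requires_grad=True)
u = v ** 2

# Unlike returning None, returning an actual tensor replaces the gradient,
# so v ends up with zeros -- but the backward pass still traverses the graph.
u.register_hook(lambda grad: torch.zeros_like(grad))

u.sum().backward()
print(v.grad)  # tensor of zeros
```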
I’m afraid there is no way to stop backprop in the middle of execution, no.
You can open an issue on GitHub if you want to discuss what we could be doing there!