Skip gradient propagation if gradients are less than a threshold

Hi,

I was thinking of an optimization for gradient computation where we don’t propagate gradients to a node’s children, i.e., detach the current node from the computation graph if the node’s gradient is less than a given threshold.

I tested this with a tensor v and u defined as v**2, then called u.sum().backward(). I registered a hook on u that returned None for its gradient, but I found that the gradient was still being computed for v.
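For reference, here is a minimal sketch of that experiment (the tensor shape and values are just placeholders):

```python
import torch

# Minimal reproduction of the experiment described above.
v = torch.randn(3, requires_grad=True)
u = v ** 2

# Hook that returns None, in the hope of detaching u from the graph.
u.register_hook(lambda grad: None)

u.sum().backward()
print(v.grad)  # still populated with 2 * v, so backward still ran through v
```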

Hence I was wondering how to do this. Of course, the naive way would be to register a hook and zero out the gradients if they are below a threshold, but while that would have the same effect, I wouldn’t get the desired speedup, defeating the purpose.

What is the purpose of doing this, and what problem would it solve?

I don’t know if your idea of “stopping” the backward pass is currently possible, but @albanD might know if any utils could be helpful.

Returning “None” does not work, as we treat that as “not changing the gradient”. This is because, in Python, any function that doesn’t return anything implicitly returns None.
You can return torch.zeros_like(grad) to have gradients full of 0s.
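A rough sketch of what that hook could look like, with an arbitrary threshold added for illustration (the cutoff value and tensor names are not part of the actual API, just an example):

```python
import torch

threshold = 1e-3  # arbitrary cutoff, just for illustration

v = torch.randn(3, requires_grad=True)
u = v ** 2

def hook(grad):
    # Returning a tensor replaces the incoming gradient; returning None keeps it.
    if grad.abs().max() < threshold:
        return torch.zeros_like(grad)
    return grad

u.register_hook(hook)
u.sum().backward()
# v.grad is still computed either way; zeroing only changes the values flowing
# backward, it does not skip the backward computation for v.
print(v.grad)
```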

But I’m guessing I won’t get the speedup, right?

I’m afraid there is no way to stop the backprop in the middle of the execution, no :confused:
You can open an issue on GitHub if you want to discuss what we could do there!