When is the best time to zero the gradients?

I found code like this:

opt.step()
opt.zero_grad()

and I also found code like this:

opt.zero_grad()
opt.step()

Which one would be more correct? Am I missing something obvious here?
I found this answer helpful, but is it still wrong to zero the grads after step()? It seems we might not need them anymore once step() has run.

Hi,

The second one is most likely a bug.
Remember that .backward() accumulates gradients, so zero_grad() should be called to clear the old gradients before the new ones are accumulated.
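To see the accumulation concretely, here is a minimal sketch using a single hypothetical scalar parameter `w` and the toy loss `3 * w` (neither is from the thread). Calling `.backward()` twice without clearing adds the two gradients together:

```python
import torch

# One scalar parameter; loss = 3 * w, so d(loss)/dw = 3 on every pass.
w = torch.tensor(1.0, requires_grad=True)

for _ in range(2):
    loss = 3 * w
    loss.backward()  # adds 3 to w.grad; the previous value is not cleared

accumulated = w.grad.item()
print(accumulated)  # 6.0

w.grad = None        # clear the old gradient (zero_grad() does this, or fills it with zeros)
(3 * w).backward()
fresh = w.grad.item()
print(fresh)  # 3.0
```

Only after clearing does the gradient reflect the latest backward pass alone.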
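The second example here actually zeroes out the gradients just before using them, so the step won't do much: it will always be given zero gradients (it may still move the parameters due to regularization and momentum terms, though). In recent PyTorch versions zero_grad() sets the gradients to None by default, and the built-in optimizers skip parameters whose .grad is None, so in that case the step does nothing at all.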
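A small sketch of the buggy ordering, assuming a hypothetical `nn.Linear` model, random data, and plain SGD with no momentum or weight decay (so there is no other term that could move the weights). Calling `zero_grad()` between `backward()` and `step()` leaves the parameters exactly where they were:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(2, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)  # no momentum, no weight decay

x = torch.randn(4, 2)
y = torch.randn(4, 1)

before = [p.detach().clone() for p in model.parameters()]

loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
opt.zero_grad()  # wipes the gradients backward() just computed (to zeros or to None)
opt.step()       # nothing left to apply

after = [p.detach().clone() for p in model.parameters()]
unchanged = all(torch.equal(b, a) for b, a in zip(before, after))
print(unchanged)  # True
```

With momentum or weight decay enabled and zero-filled gradients, the parameters could still drift, which is the caveat mentioned above.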

In general, I would advise calling zero_grad() before backward() instead of after step(), so that you are sure the gradients are always cleared (especially in the first iteration of your loop).
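The recommended ordering could look like the following minimal loop. The model, data, learning rate, and number of iterations are hypothetical placeholders; the point is only the order `zero_grad()` → `backward()` → `step()`, which guarantees each update uses gradients from the current batch even on the first iteration:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(2, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(16, 2)
y = torch.randn(16, 1)

losses = []
for _ in range(20):
    opt.zero_grad()   # clear anything left over, including before the very first backward
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()   # gradients now reflect only this iteration
    opt.step()        # update using those gradients
    losses.append(loss.item())

print(losses[0], losses[-1])  # loss should go down over the iterations
```

Zeroing after step() also works once the loop is running, but any gradients that exist before the first iteration would then leak into the first update.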
