Why do we need to set the gradients manually to zero in pytorch?

PyTorch does not free the graph until a backward call (or until the variables go out of scope, see
https://discuss.pytorch.org/t/how-to-free-graph-manually/9255/2).
In the second case, every operation like total_loss = total_loss + loss adds new nodes to the graph.
So in every iteration, a subgraph with the same structure (as long as the Python logic stays the same) but with different values is added to the graph.
The graph is only freed every 64 iterations, on the call to total_loss.backward().
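A minimal sketch of that accumulation pattern (model, optimizer, criterion and data_loader are assumed placeholders, not from the original post):

```python
accumulation_steps = 64
total_loss = 0.0  # becomes a tensor with a graph attached after the first addition

for i, (inputs, targets) in enumerate(data_loader):
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    # Each addition keeps this iteration's subgraph alive,
    # so memory grows until backward() is called.
    total_loss = total_loss + loss

    if (i + 1) % accumulation_steps == 0:
        total_loss.backward()    # backprops through all 64 subgraphs and frees them
        optimizer.step()
        optimizer.zero_grad()    # gradients must be cleared manually
        total_loss = 0.0         # start a fresh accumulation
```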
