PyTorch does not free the graph until backward() is called on it (or until the tensors referencing it go out of scope, see
https://discuss.pytorch.org/t/how-to-free-graph-manually/9255/2).
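A minimal sketch of "freed on backward", assuming a toy one-parameter model (the names `w` and `loss` are hypothetical). After the first backward() the buffers saved by the graph are released, so a second backward() over the same graph raises a RuntimeError unless the first call passed retain_graph=True:

```python
import torch

w = torch.randn(3, requires_grad=True)
loss = (w * torch.randn(3)).sum()  # MulBackward0 saves the random operand for w's gradient

loss.backward()  # gradients computed; the graph's saved buffers are freed here

try:
    loss.backward()  # second pass needs the buffers that were just freed
    second_backward_failed = False
except RuntimeError as err:
    # PyTorch reports that the graph was already freed
    second_backward_failed = True
    print(err)
```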
In the second case, every operation such as total_loss = total_loss + loss
adds new nodes to the graph.
So in every iteration, a subgraph with the same structure (if the Python logic is the same) but different values is added to the graph.
The graph is only freed every 64 iterations, on the call to total_loss.backward().
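A minimal sketch of the growth described above, assuming a toy model with a single parameter `w` and a window of 64 accumulation steps (`w`, `ACCUM_STEPS` and the helper `count_graph_nodes` are hypothetical names, not from the original code). The helper walks `grad_fn.next_functions` to count the autograd nodes reachable from `total_loss`; the count grows by the same amount every iteration until backward() is called and the reference is dropped:

```python
import torch

def count_graph_nodes(t):
    """Count autograd nodes reachable from t.grad_fn by walking next_functions."""
    seen = set()
    stack = [t.grad_fn] if t.grad_fn is not None else []
    while stack:
        node = stack.pop()
        if node is None or node in seen:  # None marks inputs that need no grad
            continue
        seen.add(node)
        stack.extend(fn for fn, _ in node.next_functions)
    return len(seen)

ACCUM_STEPS = 64  # hypothetical: mirrors "every 64 iterations"

w = torch.randn(3, requires_grad=True)
total_loss = 0
node_counts = []

for step in range(ACCUM_STEPS):
    loss = (w * torch.randn(3)).sum()  # same Python logic every iteration...
    total_loss = total_loss + loss     # ...so a same-shaped subgraph is appended
    node_counts.append(count_graph_nodes(total_loss))

total_loss.backward()  # the accumulated graph's buffers are freed here
total_loss = 0         # drop the reference so the next window starts a fresh graph

print(node_counts[0], node_counts[-1])
```

Each iteration appends one multiply, one sum and one add node, while the leaf node for `w` is shared, so memory held by the graph scales with the number of iterations between backward() calls.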