For my own understanding, I would like to explicitly see that the graph is freed after optimizer.step()
is called, rather than relying on an implicit indication like the error we get if we call backward()
again. What user-facing indications do we have that the graph is freed? I tried checking .grad_fn, but it isn't set to None. In other words, what can I print before and after backward() to see plainly that the graph has been freed?
Most discussions about how the graph is freed are about how an error is raised if one calls loss.backward() again after optimizer.step() without retain_graph, etc. My question is different.
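import torch
from torch import nn

f = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))
optimizer = torch.optim.Adam(f.parameters(), lr=0.1)
x = torch.randn(2, 10)
y = x @ torch.ones(10, 1)
for ii in range(10):
    optimizer.zero_grad()  # clear gradients left over from the previous iteration
    pred = f(x)
    l = nn.MSELoss()(pred, y)
    print(pred.grad_fn)  # before backward()
    l.backward()
    print(pred.grad_fn)  # after backward()
    optimizer.step()
    print(pred.grad_fn)  # after optimizer.step()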
All the prints yield <AddmmBackward object at 0x7f0ee95450f0>
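
One thing that might be worth printing instead of .grad_fn itself is the tensors that a node saved for the backward pass. The sketch below is only an assumption-laden illustration, not a confirmed answer: it assumes a recent PyTorch version in which grad_fn nodes expose `_saved_*` attributes (here `_saved_result` on the node produced by `exp()`), and it assumes that reading such an attribute after backward() raises a RuntimeError once the saved buffers have been released. The helper `saved_tensor_status` is a hypothetical name introduced for this example.

```python
import torch

x = torch.randn(3, requires_grad=True)
y = x.exp()       # the exp node saves its result for use in backward()
loss = y.sum()


def saved_tensor_status(node):
    """Hypothetical helper: report whether the node's saved result is still readable."""
    try:
        node._saved_result  # assumption: this attribute exists on the exp node
        return "alive"
    except RuntimeError:
        # Raised when the saved tensor has already been released
        return "freed"


before = saved_tensor_status(y.grad_fn)
loss.backward()   # no retain_graph, so saved buffers should be released here
after = saved_tensor_status(y.grad_fn)

print("before backward():", before)
print("after backward(): ", after)
print("grad_fn is still a node object:", y.grad_fn is not None)
```

If those assumptions hold, the node object itself stays reachable through .grad_fn (which would explain why the prints above never change), while the tensors it saved are what gets released by backward().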