The PyTorch tutorial indicates that retain_graph has to be True when calling backward twice or more:
if you even want to do the backward on some part of the graph twice, you need to pass in retain_graph = True during the first pass.
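For reference, the documented pattern would look something like this (a minimal sketch of my understanding; the shapes and values are arbitrary):

    import torch

    x = torch.rand(1, 4, requires_grad=True)
    y = x ** 2
    # retain_graph=True keeps the graph's buffers alive for a second backward
    y.backward(torch.ones(1, 4), retain_graph=True)
    y.backward(torch.ones(1, 4))  # second pass succeeds
    print(x.grad)                 # gradients from both passes are accumulated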
However, I found that in some situations this is not the case, as shown in the following snippet (using PyTorch 1.0):
    import torch

    a = torch.rand(1, 4, requires_grad=True)
    b = a + 2
    b.backward(torch.ones(1, 4))  # note: no retain_graph=True
    b.backward(torch.ones(1, 4))  # there should be an error here!
    print(a.grad)                 # tensor([[2., 2., 2., 2.]])
Strangely, if I change b = a + 2 in the code above to b = a**2, then the second b.backward call raises the expected error:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
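For completeness, here is the variant that fails for me (the same snippet with only the b = line changed):

    import torch

    a = torch.rand(1, 4, requires_grad=True)
    b = a ** 2
    b.backward(torch.ones(1, 4))  # first backward frees the graph's buffers
    b.backward(torch.ones(1, 4))  # raises the RuntimeError quoted above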
Could anyone explain this difference? Thanks!