I understand that calling backward will free the computation graph, and as a result a second call to it will throw an exception, e.g.:
x = Variable(torch.Tensor(5, 3), requires_grad=True)
y = Variable(torch.Tensor(5, 3), requires_grad=True)
z = torch.mm(x, torch.transpose(y, 0, 1))
z.backward(torch.ones(5, 5), retain_graph=False)
z.backward(torch.ones(5, 5), retain_graph=False) # >> Throws RuntimeError
However, this seems to work differently for some operators, e.g. elementwise addition:
x = Variable(torch.Tensor(5, 3), requires_grad=True)
y = Variable(torch.Tensor(5, 3), requires_grad=True)
z = x + y
z.backward(torch.ones(5, 3), retain_graph=False)
# You can call it multiple times and the grad will just accumulate
z.backward(torch.ones(5, 3), retain_graph=False)
I am using PyTorch '0.2.0_4'; can someone explain why the second case works differently? Thanks!
If retain_graph=False, the intermediate outputs needed for the backward computation are freed. If there are no saved intermediate outputs (as in the case of addition), then subsequent calls may still happen to work, but you should not rely on that behavior. If you need multiple backward passes, pass retain_graph=True on all but the last call.
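To illustrate, here is a sketch of both cases using the modern torch API (Variable was later merged into Tensor, so recent releases take requires_grad directly on the tensor); the variable names are just examples. The matmul node saves its inputs for backward, so the supported way to call backward twice is retain_graph=True. The addition node saves nothing, which is why a second backward happens to succeed and the gradients simply accumulate.

```python
import torch

# Case 1: torch.mm saves tensors for its backward pass.
# Passing retain_graph=True keeps those buffers alive so a
# second backward call is supported, not accidental.
x = torch.ones(5, 3, requires_grad=True)
y = torch.ones(5, 3, requires_grad=True)
z = torch.mm(x, y.t())
z.backward(torch.ones(5, 5), retain_graph=True)  # keep saved buffers
z.backward(torch.ones(5, 5))                     # second call is now safe

# Case 2: addition saves no intermediate outputs, so a second
# backward with retain_graph=False happens to work, but that is
# an implementation detail you should not rely on.
a = torch.ones(5, 3, requires_grad=True)
b = torch.ones(5, 3, requires_grad=True)
c = a + b
c.backward(torch.ones(5, 3))
c.backward(torch.ones(5, 3))  # gradients accumulate into a.grad / b.grad
print(a.grad.sum().item())    # 5*3 ones, added twice -> 30.0
```

Note that the accumulation in case 2 is the same .grad accumulation you get from any repeated backward pass; call a.grad.zero_() (or optimizer.zero_grad() in a training loop) if you don't want the gradients summed.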