How to understand 'calling .backward() clears the computation graph'

I understand that calling backward will clear the computation graph and as a result, a second call of it will throw an exception. e.g.,

import torch
from torch.autograd import Variable

x = Variable(torch.Tensor(5, 3), requires_grad=True)
y = Variable(torch.Tensor(5, 3), requires_grad=True)
z = torch.mm(x, torch.transpose(y, 0, 1))
z.backward(torch.ones(5, 5), retain_graph=False)
z.backward(torch.ones(5, 5), retain_graph=False)  # >> Throws RuntimeError

However, this seems to work differently for some operators, e.g., elementwise addition:

x = Variable(torch.Tensor(5, 3), requires_grad=True)
y = Variable(torch.Tensor(5, 3), requires_grad=True)
z = x + y
z.backward(torch.ones(5, 3), retain_graph=False)
# You can call it multiple times and the grad will just accumulate
z.backward(torch.ones(5, 3), retain_graph=False) 

I am using PyTorch '0.2.0_4'. Can someone explain why the second case works differently? Thanks!

See the answer here:

If retain_graph=False, the intermediate outputs needed for the backward computation are freed. If there are no saved intermediate outputs (as in the case of addition), then subsequent calls may still work, but you should not rely on that behavior.
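For what it's worth, here is a minimal sketch of the matmul case showing both points at once: passing retain_graph=True keeps the saved buffers alive so a second backward call is allowed, and the gradients from both calls simply accumulate into x.grad. (This uses random inputs instead of uninitialized tensors so the printed gradient is meaningful.)

import torch
from torch.autograd import Variable

x = Variable(torch.randn(5, 3), requires_grad=True)
y = Variable(torch.randn(5, 3), requires_grad=True)

# mm saves its inputs for the backward pass, so freeing the graph matters here
z = torch.mm(x, torch.transpose(y, 0, 1))

# retain_graph=True keeps the saved buffers, so a second call is legal
z.backward(torch.ones(5, 5), retain_graph=True)
z.backward(torch.ones(5, 5))  # fine: the graph was retained above

# gradients accumulate across calls; x.grad now holds the sum of both passes
print(x.grad)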


Cool. That makes sense. Thank you for the clarification.