I tried calling .backward(retain_graph=True) twice:
```python
import torch

x = torch.ones(2, 2, requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()

out.backward(retain_graph=True)
print(x.grad)

out.backward(retain_graph=True)
print(x.grad)
```
The first print(x.grad) outputs tensor([[4.5000, 4.5000], [4.5000, 4.5000]]).
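For reference, the 4.5 from the first call can be reproduced by hand. This is a short derivation assuming the graph above (y = x + 2, z = 3y², out = mean of the four entries of z), evaluated at x_i = 1:

```latex
\text{out} = \frac{1}{4}\sum_{i=1}^{4} z_i = \frac{1}{4}\sum_{i=1}^{4} 3\,(x_i + 2)^2,
\qquad
\frac{\partial\,\text{out}}{\partial x_i} = \frac{3}{2}\,(x_i + 2) = \frac{3}{2}\cdot 3 = 4.5 .
```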
I wondered what would happen if out.backward() were called twice, so I tried it.
The second print outputs tensor([[9., 9.], [9., 9.]]).
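A minimal sketch of what appears to be going on, assuming PyTorch's documented behavior that backward() accumulates into .grad instead of overwriting it: the second call adds another 4.5 to the 4.5 already stored, giving 9. Zeroing x.grad between calls (x.grad.zero_() here) restores the single-pass value. The names first, second and third are only illustrative.

```python
import torch

# Same graph as in the question, written compactly.
x = torch.ones(2, 2, requires_grad=True)
out = ((x + 2) * (x + 2) * 3).mean()

# First backward: x.grad starts as None, so it becomes d(out)/dx = 1.5 * (x + 2) = 4.5.
out.backward(retain_graph=True)
first = x.grad.clone()

# Second backward: the freshly computed 4.5 is ADDED to the existing x.grad.
out.backward(retain_graph=True)
second = x.grad.clone()

# Clearing x.grad first gives the un-accumulated gradient again.
x.grad.zero_()
out.backward()  # last use of the graph, so retain_graph is not needed here
third = x.grad.clone()

print(first)   # every entry 4.5
print(second)  # every entry 9.0
print(third)   # every entry 4.5
```

This is why training loops usually call optimizer.zero_grad() (or zero the .grad tensors manually) before each backward pass.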
I can't understand how this result came out.
Can you please explain?