Similar to the StackOverflow question "pytorch - How to set gradients to Zero without optimizer?", I want to reset the accumulated gradients so that I can call loss.backward() multiple times without using an optimizer. How can I do this?
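One approach that avoids the zeroing step entirely, sketched under the assumption that you only need the gradient values and not a populated `.grad` attribute: `torch.autograd.grad` returns the gradients as new tensors instead of accumulating them into `.grad`, so repeated calls on the same graph do not interfere with each other. The tensors and seed vectors below simply mirror the example that follows.

```python
import torch

x = torch.randn(3, requires_grad=True)
t = torch.randn(3, requires_grad=True)
y = x + t
z = y + y.flip(0)

# Each call returns fresh gradient tensors; nothing is accumulated into x.grad or t.grad.
# retain_graph=True keeps the graph alive so it can be differentiated again.
gx1, gt1 = torch.autograd.grad(
    z, (x, t), grad_outputs=torch.tensor([1., 0., 0.]), retain_graph=True
)
gx2, gt2 = torch.autograd.grad(
    z, (x, t), grad_outputs=torch.tensor([0., 1., 0.]), retain_graph=True
)
print(gx1, gt1)
print(gx2, gt2)
```

Because the results are returned rather than stored, there is nothing to reset between calls; `x.grad` and `t.grad` remain `None` throughout.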
Take the example given in that post. It has two variables, x and t, but I'd like a way to zero all gradients at once rather than one variable at a time.
```python
import torch

x = torch.randn(3, requires_grad=True)
t = torch.randn(3, requires_grad=True)
y = x + t
z = y + y.flip(0)

# First backward pass; retain_graph=True keeps the graph for a second call
z.backward(torch.tensor([1., 0., 0.]), retain_graph=True)
print(x.grad)
print(t.grad)

# Gradients accumulate across backward() calls, so both must be reset by hand
x.grad.zero_()
t.grad.zero_()

z.backward(torch.tensor([0., 1., 0.]), retain_graph=True)
print(x.grad)
print(t.grad)
```
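A minimal sketch of zeroing every gradient in one call without an optimizer, assuming the leaf tensors are collected in a list. The helper name `zero_grads` and its `set_to_none` flag are hypothetical, not a PyTorch API; setting `.grad` to `None` is an alternative to `zero_()`, after which PyTorch allocates a fresh gradient on the next backward call.

```python
import torch

def zero_grads(tensors, set_to_none=False):
    """Hypothetical helper: reset the .grad of every tensor in `tensors`."""
    for p in tensors:
        if p.grad is None:
            continue  # no gradient has been accumulated for this tensor yet
        if set_to_none:
            p.grad = None  # drop the gradient tensor entirely
        else:
            p.grad.zero_()  # reset in place, keeping the allocated tensor

x = torch.randn(3, requires_grad=True)
t = torch.randn(3, requires_grad=True)
params = [x, t]  # every tensor whose gradient should be reset

y = x + t
z = y + y.flip(0)

z.backward(torch.tensor([1., 0., 0.]), retain_graph=True)
zero_grads(params)  # one call instead of zeroing x and t individually
z.backward(torch.tensor([0., 1., 0.]), retain_graph=True)
print(x.grad)
print(t.grad)
```

If the tensors are registered as parameters of an `nn.Module`, `Module.zero_grad()` performs the same kind of loop over `module.parameters()`, so no separate list is needed.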