Hi folks,
I am wondering if there is a good way to recycle (free and reuse the memory of) the gradient tensor passed to backward(). Here is a small example:
import torch

x = torch.rand(4, 4)
w = torch.nn.Linear(4, 4)

with torch.set_grad_enabled(True):
    y = w(x)

dy = torch.rand(4, 4)  # the gradient tensor I would like to recycle
y.backward(dy)
I am hoping to recycle dy in the middle of y.backward(), as soon as it is no longer needed for the chain rule (here, I believe that is once the Linear node's backward has used it to compute the weight and bias gradients). How can I do that?
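For concreteness, here is a sketch of the kind of thing I am imagining. It assumes (I have not verified this) that a hook registered with Tensor.register_hook on w.weight fires only after dy has been consumed by the Linear node's backward, and that dropping my last Python reference at that point would let the allocator reclaim the memory:

import torch

x = torch.rand(4, 4)
w = torch.nn.Linear(4, 4)
y = w(x)

holder = [torch.rand(4, 4)]  # wrap dy so the hook can drop the last reference

def release(_grad):
    # fires once the gradient w.r.t. w.weight has been computed, i.e.
    # (assumption) after dy has been consumed by the chain rule; dropping
    # the reference here might let the allocator reclaim it mid-backward
    holder[0] = None

w.weight.register_hook(release)
y.backward(holder[0])

Is something along these lines the right approach, or does the autograd engine keep its own reference to the gradient for the whole call?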
Thanks!