I don’t want to accumulate gradients (which is what happens when I call backward multiple times on the same computation graph). I want to extract the gradients of the params with respect to the loss (call them dl_dw) and use those gradients in later computations of J and dJ_dw. The way I’m doing it now is not right: I can’t even call backward on J, because when I try to zero out the gradients before calling backward on J, PyTorch halts due to what it flags as illegal in-place operations.
However, the in-place operation is not really illegal in my case, yet PyTorch still refuses to run it. Is there a way to extract intermediate gradients without PyTorch halting without my permission?
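To make the intent concrete, here is roughly the flow I’m after. I’m assuming torch.autograd.grad is the right tool for this, since it seems to hand back the gradient as a plain tensor instead of accumulating it into .grad, but I’m not sure it’s the intended approach:
import torch
x = torch.ones(10, requires_grad=True)
y = x**2
z = x**3
l = (x - 2).sum()
# Ask autograd for dl/dx directly; x.grad is never touched, so nothing accumulates.
(g_x,) = torch.autograd.grad(l, x)
# Use the extracted gradient as a constant when building the second objective J.
J = (y + z + g_x).sum()
J.backward()   # fills x.grad with dJ/dx; g_x is treated as a constant here
print(x.grad)
One thing I’m unsure about: without create_graph=True the extracted g_x is detached, so I don’t think J.backward() would differentiate through it if that’s ever needed.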
In case you need it, here is the new example code I cooked up for this:
import torch
from torchviz import make_dot
x = torch.ones(10, requires_grad=True)
weights = {'x':x}
y = x**2
z = x**3
l = (x-2).sum()
l.backward()
g_x = x.grad  # dl/dx from the first backward pass
#g_x.requires_grad = True ## Adds it to the computation graph!
print(f'g_x = x.grad = {x.grad}\n')
x.zero_()  # <-- this line raises the RuntimeError below
#weights['g_x'] = g_x
#print(f'weights = {weights}\n')
J = (y + z + g_x).sum()
J.backward()
print(f'g_x = x.grad = {x.grad}\n')
make_dot(J,params=weights)
Error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-27-c4086df8154e> in <module>
20 print(f'g_x = x.grad = {x.grad}\n')
21
---> 22 x.zero_()
23
24 #weights['g_x'] = g_x
RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.
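For what it’s worth, the error doesn’t seem specific to my graph; as far as I can tell, any in-place op on a leaf tensor with requires_grad=True triggers it:
import torch
w = torch.ones(3, requires_grad=True)
w.zero_()   # RuntimeError: a leaf Variable that requires grad is being used in an in-place operation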