Hi!
I wonder if there is a way to re-register the same backward hook after its handle has been removed? Specifically, my layer’s custom backward function needs to use the Jacobian information of the output, and one way of doing this is:
import torch.nn as nn
from torch import autograd

class MyLayer(nn.Module):
    def __init__(self, ...):
        ...
        self.hook = None

    def forward(self, x):
        x.requires_grad_()
        y = self.net(x)

        def backward_hook(grad):
            if self.hook is not None:
                # remove so the autograd.grad call below does not re-trigger this hook
                self.hook.remove()
            return autograd.grad(y, x, grad)[0]

        self.hook = y.register_hook(backward_hook)
        return y
In the code above, the self.hook.remove() is needed because otherwise the autograd.grad(y, x, ...) call would recursively invoke the backward_hook function. When I have one loss per training iteration and backprop through it, everything works fine.
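To make the recursion point concrete, here is a minimal, self-contained sketch with made-up tensors instead of the actual self.net; a call counter stands in for the remove() call, and the retain_graph=True in the inner call is my addition so it does not free buffers the enclosing backward still needs:

import torch
from torch import autograd

x = torch.randn(3, requires_grad=True)
y = x * 2

calls = 0

def backward_hook(grad):
    global calls
    calls += 1
    if calls > 1:
        # re-entrant invocation triggered by the autograd.grad call below;
        # bail out here instead of recursing forever
        return grad
    # this inner call computes the gradient of y w.r.t. x, but it also
    # re-triggers the hook registered on y
    return autograd.grad(y, x, grad, retain_graph=True)[0]

y.register_hook(backward_hook)
y.sum().backward()
print(calls)  # 2: the inner autograd.grad re-entered the hook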
However, one obvious drawback of the implementation above is that if I have two losses, loss1(y) and loss2(y), that both depend on the output y, and I do

loss1(y).backward()
loss2(y).backward()

then the second loss’s (i.e., loss2’s) backward pass will not go through the backward_hook function, because the handle was already removed during loss1(y).backward(). That is not the expected behavior. (For example, this problem shows up when running gradcheck.)
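Here is a small reproduction of that behavior, again with made-up tensors rather than the real layer (and again with retain_graph=True added by me in the inner call so the graph stays alive across both backward passes):

import torch
from torch import autograd

x = torch.randn(3, requires_grad=True)
y = x * 2
hook_calls = 0

def backward_hook(grad):
    global hook_calls
    hook_calls += 1
    handle.remove()  # stops the recursion, but also unhooks every later backward pass
    return autograd.grad(y, x, grad, retain_graph=True)[0]

handle = y.register_hook(backward_hook)

loss1 = y.sum()
loss2 = (y ** 2).sum()
loss1.backward(retain_graph=True)  # goes through backward_hook
loss2.backward()                   # handle already removed: the custom backward is skipped
print(hook_calls)                  # prints 1, not 2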
One way to work around this issue is to clone and detach x inside the layer and recompute the output from the copy; i.e.,
class MyLayer(nn.Module):
    def forward(self, x):
        y = self.net(x)
        # differentiate through a detached copy, so the hook's autograd.grad
        # never touches the hook registered on y and no remove() is needed
        xc = x.clone().detach().requires_grad_()
        yc = self.net(xc)

        def backward_hook(grad):
            return autograd.grad(yc, xc, grad)[0]

        y.register_hook(backward_hook)
        return y
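For completeness, a runnable sketch of this workaround as I understand it: the nn.Linear and the sizes are placeholders I made up for self.net, and I have added retain_graph=True to the inner call so the yc graph can be differentiated more than once. Both backward passes now go through the hook:

import torch
import torch.nn as nn
from torch import autograd

class MyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 4)  # placeholder for the real sub-network

    def forward(self, x):
        y = self.net(x)
        xc = x.clone().detach().requires_grad_()
        yc = self.net(xc)  # second forward pass: the extra cost mentioned below

        def backward_hook(grad):
            return autograd.grad(yc, xc, grad, retain_graph=True)[0]

        y.register_hook(backward_hook)
        return y

layer = MyLayer()
x = torch.randn(2, 4, requires_grad=True)
y = layer(x)
y.sum().backward(retain_graph=True)  # first loss: goes through the hook
(y ** 2).sum().backward()            # second loss: still goes through the hook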
But this requires calling self.net(xc) again, which adds computational and memory cost. I wonder if there is a way to re-register the hook after we have already removed it (as in the first code block), so that the custom backward pass works even if backward is called multiple times, without needing to clone anything?
Thanks!