Re-register a removed backward hook within the backward function

Hi!

I wonder if there is a way to re-register the same backward hook after we have removed its handle? Specifically, my layer’s custom backward function needs to use the Jacobian information of the output, and one way of doing this is:

import torch
from torch import nn, autograd

class MyLayer(nn.Module):
    def __init__(self, ...):
        ...
        self.hook = None

    def forward(self, x):
        x.requires_grad_()
        y = self.net(x)

        def backward_hook(grad):
            if self.hook is not None:
                # Remove the handle so the autograd.grad call below
                # does not re-trigger this hook recursively
                self.hook.remove()
            return autograd.grad(y, x, grad)[0]

        self.hook = y.register_hook(backward_hook)
        return y

In the code above, the self.hook.remove() is needed because otherwise the autograd.grad(y, x, ...) call would recursively trigger backward_hook. When I have one loss per training iteration and backprop through it, everything works fine.
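For reference, here is the same pattern stripped down to a single tensor (a toy x * 2 in place of self.net, plus retain_graph=True in the inner call so the outer backward can still traverse the graph; both are just for this sketch):

import torch
from torch import autograd

x = torch.randn(3, requires_grad=True)
y = x * 2  # toy stand-in for self.net(x)

def backward_hook(grad):
    handle.remove()  # without this, the grad call below re-enters the hook
    return autograd.grad(y, x, grad, retain_graph=True)[0]

handle = y.register_hook(backward_hook)
y.sum().backward()  # runs without infinite recursion; the hook fires once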

However, one obvious drawback of the implementation above is that if I have two losses, loss1(y) and loss2(y), that both depend on the output y, and if I do

loss1(y).backward(retain_graph=True)  # retain the graph so the second backward can run
loss2(y).backward()

the backward pass of the second loss (i.e., loss2) will not go through backward_hook, because the handle was already removed during loss1(y).backward(). That is not the expected behavior. (For example, this problem shows up if I run gradcheck, which backprops through the output multiple times.)
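The issue is really just that the hook removes itself, so only the first backward sees it. A toy version (x * 2 in place of self.net, the autograd.grad call dropped since only the removal matters here, and retain_graph=True only so the second backward is allowed):

import torch

x = torch.randn(3, requires_grad=True)
y = x * 2  # toy stand-in for self.net(x)

calls = []
def backward_hook(grad):
    calls.append(True)
    handle.remove()  # same self-removal as in MyLayer above
    return grad

handle = y.register_hook(backward_hook)

y.sum().backward(retain_graph=True)  # first loss: hook fires
y.sum().backward()                   # second loss: hook is gone, so it is skipped
print(len(calls))  # prints 1, but I would expect 2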

One way to work around this issue is to clone and detach x inside the layer and recompute the output from the detached copy; i.e.,

class MyLayer(nn.Module):
    def forward(self, x):
        y = self.net(x)

        # Recompute the output from a detached copy of x; the hook's
        # autograd.grad call then walks this separate graph, which has no
        # hook registered on it, so there is no recursion and nothing to remove
        xc = x.clone().detach().requires_grad_()
        yc = self.net(xc)

        def backward_hook(grad):
            return autograd.grad(yc, xc, grad)[0]

        y.register_hook(backward_hook)
        return y

But this requires calling self.net(xc) a second time, which adds computational and memory cost. I wonder if there is a way to re-register the hook after it has been removed (as in the first code block), so that the custom backward pass works even when it is called multiple times, without having to clone anything?

Thanks!

Hi,

I don’t think I can think of a “safe” way to do that, mainly because you would have to change the state of the Module while the backward is running to avoid the infinite recursion.
So calling two backwards concurrently will always do something bad.

Assuming that you never do concurrent backwards, you can set a flag on the Module to know if this is the “top” backward or not:

class MyLayer(nn.Module):
    def __init__(self, ...):
        ...
        self.in_backward = False

    def forward(self, x):
        x.requires_grad_()
        y = self.net(x)

        def backward_hook(grad):
            if self.in_backward:
                # We are inside the nested autograd.grad call below:
                # pass the gradient through unchanged to break the recursion
                return grad
            self.in_backward = True
            try:
                return autograd.grad(y, x, grad)[0]
            finally:
                self.in_backward = False

        y.register_hook(backward_hook)
        return y
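For what it's worth, here is the guard stripped down to a single tensor (a toy x * 2 in place of self.net, a plain dict instead of the Module attribute, and retain_graph=True so the same graph can be backproped twice; all of that is just for this sketch). The hook now fires for every backward pass instead of only the first one:

import torch
from torch import autograd

x = torch.randn(3, requires_grad=True)
y = x * 2  # toy stand-in for self.net(x)

state = {"in_backward": False, "calls": 0}

def backward_hook(grad):
    if state["in_backward"]:
        # nested invocation coming from the autograd.grad call below
        return grad
    state["calls"] += 1
    state["in_backward"] = True
    try:
        return autograd.grad(y, x, grad, retain_graph=True)[0]
    finally:
        state["in_backward"] = False

y.register_hook(backward_hook)

y.sum().backward(retain_graph=True)  # first loss
y.sum().backward()                   # second loss: hook still fires
print(state["calls"])                # prints 2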