PyTorch Hooks: How to Conditionally Modify Elements of a Gradient Based on Its Path?

I’m having trouble figuring out how to implement something I want in PyTorch: path-conditional gradient backpropagation.

For simplicity, suppose I have data with shape (batch_size, input_dimension) and a simple network that outputs, per sample, the scalar sum of two affine transformations of the input, i.e.:

import torch
import torch.nn as nn

# x: (batch_size, input_dimension), y_target: (batch_size, 1)
linear1 = nn.Linear(in_features=input_dimension, out_features=1)
linear2 = nn.Linear(in_features=input_dimension, out_features=1)
y = linear1(x) + linear2(x)
loss = torch.mean((y - y_target) ** 2)

During backprop, I’d like to update linear1's parameters using only elements in the batch where $y < 0$ and update linear2's parameters using only elements in the batch where $y > 0$.

How can I implement this?

I’ve tried register_backward_hook, but if I understand its behavior correctly, by the time the registered function is called, the gradient of the loss with respect to the parameters has already been calculated. I also tried register_hook, but it doesn’t let me mask the gradient dL/dy differently depending on which linear layer it flows into next.

@albanD @richard you guys were helpful before. Any suggestions?

I think the simplest way to state the problem is this: suppose I compute the gradient of the loss with respect to a tensor y, i.e. dL/dy. How can I create two modified versions of dL/dy and route them to different subgraphs of the computational graph during backprop?
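For the toy network above, autograd does hand the very same dL/dy to both branches, which is why some branch-level intervention is needed. A small sketch to check this, assuming random toy data (the sizes and seed are arbitrary) and using retain_grad() to inspect gradients of non-leaf tensors:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed for reproducibility
input_dimension = 3   # arbitrary toy sizes
x = torch.randn(8, input_dimension)
y_target = torch.randn(8, 1)

linear1 = nn.Linear(input_dimension, 1)
linear2 = nn.Linear(input_dimension, 1)

out_lin1 = linear1(x)
out_lin2 = linear2(x)
y = out_lin1 + out_lin2
# retain_grad() keeps .grad on non-leaf tensors so they can be inspected
for t in (out_lin1, out_lin2, y):
    t.retain_grad()

loss = torch.mean((y - y_target) ** 2)
loss.backward()

# The addition passes the same dL/dy unchanged to both branches
print(torch.equal(out_lin1.grad, y.grad))  # True
print(torch.equal(out_lin2.grad, y.grad))  # True
```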

Hi,

EDIT: Double-check the gt/lt masks below against the question: linear1 should only see samples where y < 0, and linear2 only samples where y > 0. Also, this is pseudocode, so there might be typos.

The “right” way to do this would be to actually write down the loss you want:

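import torch
import torch.nn as nn

# Assumes x: (batch_size, input_dimension), y_target: (batch_size, 1),
# lin1_opt over linear1.parameters() and lin2_opt over linear2.parameters()
linear1 = nn.Linear(in_features=input_dimension, out_features=1)
linear2 = nn.Linear(in_features=input_dimension, out_features=1)
y = linear1(x) + linear2(x)

# linear1 should only learn from the samples where y < 0
mask_neg = y.lt(0)
loss_lin1 = torch.mean((y[mask_neg] - y_target[mask_neg]) ** 2)

# linear2 should only learn from the samples where y > 0
mask_pos = y.gt(0)
loss_lin2 = torch.mean((y[mask_pos] - y_target[mask_pos]) ** 2)

# Each backward() fills .grad of BOTH layers, hence the per-layer zero_grad()
linear1.zero_grad()
loss_lin1.backward(retain_graph=True)  # keep the graph for the second backward
lin1_opt.step()

linear2.zero_grad()
loss_lin2.backward()
lin2_opt.step()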
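A self-contained sketch of this first approach on random toy data (sizes and seed are arbitrary). As a variation on calling backward() twice with per-layer zero_grad(), it uses torch.autograd.grad so that each loss is differentiated only with respect to the layer it is meant to train; the resulting gradients are then stored in .grad for whatever optimizer you use:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)   # arbitrary seed; toy sizes below are also arbitrary
input_dimension = 3
x = torch.randn(256, input_dimension)
y_target = torch.randn(256, 1)

linear1 = nn.Linear(input_dimension, 1)
linear2 = nn.Linear(input_dimension, 1)

y = linear1(x) + linear2(x)
mask_neg = y.lt(0)   # samples that should train linear1
mask_pos = y.gt(0)   # samples that should train linear2

loss_lin1 = torch.mean((y[mask_neg] - y_target[mask_neg]) ** 2)
loss_lin2 = torch.mean((y[mask_pos] - y_target[mask_pos]) ** 2)

# Differentiate each loss only w.r.t. the layer it is meant to train.
# retain_graph=True because both losses share the same forward graph.
grad_w1, grad_b1 = torch.autograd.grad(
    loss_lin1, [linear1.weight, linear1.bias], retain_graph=True)
grad_w2, grad_b2 = torch.autograd.grad(
    loss_lin2, [linear2.weight, linear2.bias])

# Hand the gradients to the parameters so any optimizer can use them
for p, g in zip([linear1.weight, linear1.bias, linear2.weight, linear2.bias],
                [grad_w1, grad_b1, grad_w2, grad_b2]):
    p.grad = g
```

Because each layer's gradient is requested explicitly, the second loss cannot leak into linear1's .grad, which is what the per-layer zero_grad() calls guard against in the snippet above.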

Another version that needs only one backward pass (each loss involves only one of the layers):

import torch
import torch.nn as nn

# Assumes x, y_target as above and opt over the parameters of both layers
linear1 = nn.Linear(in_features=input_dimension, out_features=1)
linear2 = nn.Linear(in_features=input_dimension, out_features=1)

# Same value as y, but gradients only flow into linear1
y_lin1 = linear1(x) + linear2(x).detach()
mask_neg = y_lin1.lt(0)
loss_lin1 = torch.mean((y_lin1[mask_neg] - y_target[mask_neg]) ** 2)

# Same value as y, but gradients only flow into linear2
y_lin2 = linear1(x).detach() + linear2(x)
mask_pos = y_lin2.gt(0)
loss_lin2 = torch.mean((y_lin2[mask_pos] - y_target[mask_pos]) ** 2)

loss = loss_lin1 + loss_lin2

opt.zero_grad()
loss.backward()
opt.step()
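A runnable sketch of the detach() version, again assuming random toy data and an arbitrarily chosen SGD optimizer. It reuses out1 and out2 instead of calling each layer twice; the forward values of y_lin1 and y_lin2 are identical, but each loss only sends gradients into one layer:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)   # arbitrary seed and toy sizes
input_dimension = 3
x = torch.randn(256, input_dimension)
y_target = torch.randn(256, 1)

linear1 = nn.Linear(input_dimension, 1)
linear2 = nn.Linear(input_dimension, 1)
# A single optimizer over both layers; SGD and lr=0.1 are arbitrary choices
opt = torch.optim.SGD(list(linear1.parameters()) + list(linear2.parameters()), lr=0.1)

out1, out2 = linear1(x), linear2(x)

# Same value as y = out1 + out2, but only linear1 gets gradients from it
y_lin1 = out1 + out2.detach()
mask_neg = y_lin1.lt(0)
loss_lin1 = torch.mean((y_lin1[mask_neg] - y_target[mask_neg]) ** 2)

# Same value again, but only linear2 gets gradients from it
y_lin2 = out1.detach() + out2
mask_pos = y_lin2.gt(0)
loss_lin2 = torch.mean((y_lin2[mask_pos] - y_target[mask_pos]) ** 2)

opt.zero_grad()
(loss_lin1 + loss_lin2).backward()   # one backward pass for both layers
opt.step()
```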

The more dangerous version, using hooks, would look like this:

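import torch
import torch.nn as nn

# Assumes x, y_target as above and opt over the parameters of both layers
linear1 = nn.Linear(in_features=input_dimension, out_features=1)
linear2 = nn.Linear(in_features=input_dimension, out_features=1)
out_lin1 = linear1(x)
out_lin2 = linear2(x)
y = out_lin1 + out_lin2
loss = torch.mean((y - y_target) ** 2)

# Each hook returns a new gradient for its branch instead of editing grad in place
def lin1_hook(grad):
  return grad * y.lt(0).float()  # linear1 only sees samples where y < 0
def lin2_hook(grad):
  return grad * y.gt(0).float()  # linear2 only sees samples where y > 0

out_lin1.register_hook(lin1_hook)
out_lin2.register_hook(lin2_hook)

opt.zero_grad()
loss.backward()
opt.step()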

Which one of these you want will depend on your actual, more complex use case.
The first one is the clearest, even though it might not be the fastest.
The last one will be the most efficient, but you need to be careful whenever you change anything, because you are tricking the autograd engine into not computing the “real” gradients of the loss you wrote down.
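One concrete way the hook version departs from the "real" gradients: its loss is averaged over the whole batch, while the masked losses of the first two versions are averaged over the selected samples only. A sketch on assumed toy data (sizes and seed are arbitrary), with the masks computed once from the forward values:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)   # arbitrary seed and toy sizes
input_dimension, batch_size = 3, 256
x = torch.randn(batch_size, input_dimension)
y_target = torch.randn(batch_size, 1)

linear1 = nn.Linear(input_dimension, 1)
linear2 = nn.Linear(input_dimension, 1)

out_lin1 = linear1(x)
out_lin2 = linear2(x)
y = out_lin1 + out_lin2
loss = torch.mean((y - y_target) ** 2)   # mean over the WHOLE batch

# Masks taken from the forward values; they carry no autograd history
keep1 = y.detach().lt(0).to(y.dtype)
keep2 = y.detach().gt(0).to(y.dtype)

# Each hook returns a new tensor instead of editing `grad` in place
out_lin1.register_hook(lambda grad: grad * keep1)
out_lin2.register_hook(lambda grad: grad * keep2)

loss.backward()
```

With these hooks, each layer's gradient equals that of the masked mean scaled by (number of selected samples) / batch_size, so the effective step size of each layer varies from batch to batch.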

Wow this is fantastic. Give me a sec to take a look and see if it’ll work for my use case. No worries about the pseudo-code!

OK, I see an immediate problem. This works for a single “split” gradient (at y), but I want to do this all over my computational graph. By a split gradient, I mean splitting the gradient with respect to a tensor into two or more versions and routing each version to a different part of the graph.

Is there any immediate easy solution?

I think I was misunderstanding how register_hook works. It doesn’t modify the gradient in place; instead, the hook can return a new tensor that replaces it.

You should never modify the gradient passed to the hook in place!
But the hook can return a Tensor that will be used instead of the current gradient.

So the hook can be used to replace the gradient that flows back along a given branch.
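Following that idea, one could wrap the pattern in a small helper that makes one copy of a tensor per consumer and attaches a gradient-replacing hook to each copy. split_gradient below is a hypothetical helper written for illustration, not part of PyTorch:

```python
import torch

def split_gradient(t, *grad_fns):
    """Hypothetical helper (not a PyTorch API): return one branch of `t` per fn.

    Each branch has the same forward value as `t`; during backward, the gradient
    arriving at branch k is replaced by grad_fns[k](grad) before it reaches `t`.
    """
    branches = []
    for fn in grad_fns:
        branch = t.clone()        # identity op that gets its own autograd node
        branch.register_hook(fn)  # fn must RETURN the new gradient, not edit it
        branches.append(branch)
    return branches


# Usage: the first consumer only passes gradient back where x > 0
x = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
a, b = split_gradient(x, lambda g: g * (x.detach() > 0), lambda g: g)
loss = (2 * a).sum() + (10 * b).sum()
loss.backward()
print(x.grad)  # tensor([12., 10., 12.])
```

This only splits the gradient between different consumers of t. It cannot give the inputs of a single operation (for example W and x in W x) different versions of that operation's output gradient.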


Thank you again for this! Why do you say using register_hook is the most dangerous? Isn’t my use case one of the reasons why hooks exist?

@albanD I clarified my original problem in my head. I’m afraid I misrepresented it. Suppose I have the following forward graph and I also want to compute gradients with respect to x (since x is produced by earlier operations):

y = W x + b
L = torch.mean((y - y_target) ** 2)

When calculating dL/dW = dL/dy · dy/dW, I want some elements of dL/dy set to zero under a specific condition (e.g. where y is positive). But when calculating dL/dx = dL/dy · dy/dx, I want dL/dy to stay unchanged. In other words, different versions of dL/dy should be passed back to different inputs of the same operation.

It doesn’t seem like the register_hook approach will work. Is this correct? If so, what’s my next best option?

This is a good example of where the register_hook approach would have been sneakily wrong!
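To see why, here is a sketch with assumed toy tensors (shapes and seed are arbitrary): a masking hook on y changes dL/dy before autograd splits it between W, b and x, so x loses its gradient on the masked rows as well:

```python
import torch

torch.manual_seed(0)   # arbitrary seed and toy sizes
x = torch.randn(256, 3, requires_grad=True)   # stands in for an earlier activation
W = torch.randn(1, 3, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
y_target = torch.randn(256, 1)

y = x @ W.T + b
# Hook on y: zero dL/dy wherever y is positive (intended only for W and b)
y.register_hook(lambda grad: grad * (y.detach() <= 0))
L = torch.mean((y - y_target) ** 2)
L.backward()

# The unmasked dL/dy, i.e. what x was supposed to receive it through
full_grad_y = 2 * (y - y_target).detach() / 256
```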

If you want to do that, I’m afraid you will need two backward passes anyway, so the first solution at the top would be the best.
That way, you make sure that each forward computes exactly what should be differentiated, and you get the gradients you want.
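A minimal sketch of that two-pass idea for the y = W x + b case. It assumes toy tensors (shapes and seed are arbitrary), uses torch.autograd.grad for the two differentiations instead of calling backward() twice, and keeps the full-batch normalization for both losses:

```python
import torch

torch.manual_seed(0)   # arbitrary seed and toy sizes
x = torch.randn(256, 3, requires_grad=True)   # output of earlier layers in practice
W = torch.randn(1, 3, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
y_target = torch.randn(256, 1)

y = x @ W.T + b
keep = (y.detach() <= 0).to(y.dtype)   # 0 where y is positive

# Loss whose gradient w.r.t. W and b sees dL/dy with positive-y entries zeroed
loss_param = torch.mean(keep * (y - y_target) ** 2)
# Ordinary loss, used only for the gradient w.r.t. x
loss_full = torch.mean((y - y_target) ** 2)

grad_W, grad_b = torch.autograd.grad(loss_param, [W, b], retain_graph=True)
(grad_x,) = torch.autograd.grad(loss_full, [x])
```

If x is produced by earlier operations, backprop through them can then be continued with x.backward(grad_x), and grad_W / grad_b can be written into .grad before an optimizer step.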
