I’m having trouble figuring out how to implement something I want in PyTorch: path-conditional gradient backpropagation.
For simplicity, suppose I have data of shape (batch_size, input_dimension) and a simple network that outputs a scalar sum of two affine transformations of the input, i.e.
linear1 = nn.Linear(in_features=input_dimension, out_features=1)
linear2 = nn.Linear(in_features=input_dimension, out_features=1)
y = linear1(x) + linear2(x)
loss = torch.mean((y - y_target) ** 2)
During backprop, I’d like to update linear1's parameters using only elements in the batch where $y < 0$ and update linear2's parameters using only elements in the batch where $y > 0$.
How can I implement this?
I’ve tried register_backward_hook, but if I understand the functionality correctly, by the time the registered function is called, the gradient of the error with respect to the parameters has already been calculated. I tried register_hook, but this doesn’t permit me to conditionally mask the gradient dL/dy depending on which linear layer has backward() called next.
I think the simplest explanation of the problem is the following: Suppose I compute the gradient with respect to a tensor y i.e. dL/dy. How can I create two modified versions of dL/dy and route them to different subgraphs of the computational graph during backprop?
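For concreteness, here is a minimal sketch (toy shapes, my own variable names) of one way to express this split-and-route for the toy network above: compute the masks from y's value, then rebuild y as a sum of masked terms where detach() controls which branch sees which copy of dL/dy, so a single backward() routes different masked gradients to linear1 and linear2:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_dimension = 4
x = torch.randn(8, input_dimension)
y_target = torch.randn(8, 1)

linear1 = nn.Linear(input_dimension, 1)
linear2 = nn.Linear(input_dimension, 1)

y1 = linear1(x)
y2 = linear2(x)
y = y1 + y2  # the value of y decides the routing

neg_mask = (y < 0).float().detach()
pos_mask = (y > 0).float().detach()

# y_routed has the same value as y (masks are complementary), but during
# backprop linear1 only receives dL/dy on elements where y < 0, and
# linear2 only on elements where y > 0, thanks to detach().
y_routed = neg_mask * (y1 + y2.detach()) + pos_mask * (y1.detach() + y2)
loss = torch.mean((y_routed - y_target) ** 2)
loss.backward()
```

This only needs one backward pass, at the cost of building the extra masked terms in the forward.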
Which one of these you want will depend on your actual, more complex use case.
The first one is going to be the clearest, even though it might not be the fastest.
The last one will be the most efficient, but you need to be careful whenever you change things, as you’re tricking the autograd engine into not computing the “real” gradients associated with the loss you computed.
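The solutions referred to here aren’t quoted in this excerpt, but one plausible reading of the “clearer but slower” approach is two separate forward/backward passes with masked losses, using detach() so each pass only updates one layer. A hedged sketch (toy shapes and names of my own choosing):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_dimension = 4
x = torch.randn(8, input_dimension)
y_target = torch.randn(8, 1)

linear1 = nn.Linear(input_dimension, 1)
linear2 = nn.Linear(input_dimension, 1)

# Decide the routing masks from a gradient-free forward pass.
with torch.no_grad():
    y_value = linear1(x) + linear2(x)
neg_mask = (y_value < 0).float()
pos_mask = (y_value > 0).float()

# Pass 1: only elements with y < 0 contribute, and only linear1 gets gradients
# (linear2's output is detached, so no gradient flows into it).
loss1 = torch.mean(neg_mask * (linear1(x) + linear2(x).detach() - y_target) ** 2)
loss1.backward()

# Pass 2: only elements with y > 0 contribute, and only linear2 gets gradients.
loss2 = torch.mean(pos_mask * (linear1(x).detach() + linear2(x) - y_target) ** 2)
loss2.backward()
```

Each forward computes exactly what should be differentiated for that layer, so autograd never needs to be second-guessed.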
Ok I see an immediate problem. This works for a single “split” gradient, y, but I want to do this all over my computational graph. By split gradient, I mean splitting the gradient with respect to a tensor into >= 2 versions and routing different versions to different parts of the subgraph.
@albanD I clarified my original problem in my head. I’m afraid I misrepresented it. Suppose I have the following forward graph and I also want to compute gradients with respect to x (since x is produced by earlier operations):
y = W x + b
l = torch.mean((y - y_target) ** 2)
When calculating dL/dW = (dL/dy)(dy/dW), I want dL/dy to have its elements conditionally set to zero under a specific condition (e.g. where y is positive). But when calculating dL/dx = (dL/dy)(dy/dx), I want dL/dy to have its elements unchanged. In other words, different versions of dL/dy should be passed backwards in the graph to different inputs of the same operation.
It doesn’t seem like the register_hook approach will work. Is this correct? If so, what’s my next best option?
This is a good example of where the register_hook approach would have been sneakily wrong.
If you want to do that, I’m afraid that you will need two backward passes anyway. So the first solution at the top would be the best.
That way you make sure that each forward computes exactly what should be differentiated and you will get the gradients you want.
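As an illustration of those two backward passes on the clarified W/x example (a hedged sketch with toy shapes and my own variable names, using torch.autograd.grad rather than .backward() to keep the two passes separate): the first pass computes dL/dx from the full loss, and the second pass computes dL/dW and dL/db from a masked loss whose dL/dy is zeroed where y is positive.

```python
import torch

torch.manual_seed(0)
n, d = 6, 3
# x stands in for the output of earlier operations, so we want dL/dx too.
x = torch.randn(n, d, requires_grad=True)
W = torch.randn(1, d, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
y_target = torch.randn(n, 1)

y = x @ W.t() + b
l = torch.mean((y - y_target) ** 2)

# Backward pass 1: dL/dx with the full, unmasked dL/dy.
grad_x, = torch.autograd.grad(l, x, retain_graph=True)

# Backward pass 2: dL/dW and dL/db from a masked loss, whose gradient
# w.r.t. y is exactly the original dL/dy with positive-y elements zeroed.
mask = (y <= 0).float().detach()
l_masked = torch.mean(mask * (y - y_target) ** 2)
grad_W, grad_b = torch.autograd.grad(l_masked, (W, b))
```

Note that masking the per-element loss terms here is equivalent to masking dL/dy, because each element’s squared error only influences its own y.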