Calculating gradients only on some part of the graph

Hi, assume we have one input x and we calculate:

x’ = conv_layer(x)
x’’ = conv_layer(x’)
loss2 = criterion(x’’)

and I want the conv_layer parameters to be updated only by the gradients of the first path (x → x’). The gradients from the second path (x’ → x’’) should not be accumulated into the parameters, so that they do not influence my optimizer.step() call.

One trivial solution might be to create a layer that overrides backward: in the backward pass of x’ → x’’ it would zero the gradients manually instead of accumulating them (or, rather than zeroing, subtract the added grad value, or act as an identity backward that leaves the current grad unchanged, i.e. stays neutral), while behaving normally on the path that I do want to influence my optimizer.step().

Another, similar solution might be to use a backward hook to take the gradients at the point where the gradient paths of x → x’ and x’ → x’’ connect, then zero the grads (I'm not sure whether I can zero only the part that comes from the unwanted path), and then call x’.backward(hooked_grads_from_the_start_of_path_x’->x’’).
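The manual idea above can be sketched roughly as follows. This is only an illustration under assumptions: conv_layer is taken to be a shape-preserving nn.Conv2d, and a squared mean stands in for criterion. Instead of a hook it uses torch.autograd.grad to obtain the gradient at x’ (which does not write into the parameters' .grad fields) and then passes that gradient to x’.backward(...).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Assumed stand-in for the layer in the question (keeps the input shape)
conv_layer = nn.Conv2d(3, 3, kernel_size=3, padding=1)
x = torch.randn(1, 3, 8, 8)

x1 = conv_layer(x)                      # first pass: x -> x'
x1_cut = x1.detach().requires_grad_()   # cut the graph at x'
x2 = conv_layer(x1_cut)                 # second pass: x' -> x''
loss2 = x2.pow(2).mean()                # placeholder for criterion(x'')

# d loss2 / d x'; autograd.grad returns it without accumulating into
# conv_layer.weight.grad or conv_layer.bias.grad
(grad_x1,) = torch.autograd.grad(loss2, x1_cut)

# Backpropagate that gradient through the first pass only
x1.backward(grad_x1)
```

After this, conv_layer.weight.grad holds only the contribution of the x → x’ path, scaled by the gradient that arrived from the second path.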

Does PyTorch provide some other, more convenient and easy way to do this?
Thank you in advance.

You could set the requires_grad attribute of the layer's parameters to False before the second forward pass and reset it to True afterwards, or alternatively you could use the functional API with detached parameters:

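import torch
import torch.nn as nn
import torch.nn.functional as F

lin = nn.Linear(1, 1, bias=False)
x = torch.randn(1, 1)

# Approach 1: freeze the parameters for the second forward pass only
out = lin(x)
for param in lin.parameters():
    param.requires_grad = False
out = lin(out)
for param in lin.parameters():
    param.requires_grad = True
out.backward()
print(lin.weight.grad)
> tensor([[0.2955]])

# Approach 2: functional API with detached parameters
lin.zero_grad()  # clear the gradient from approach 1 before comparing
out = lin(x)
out = F.linear(out, lin.weight.detach())
out.backward()
print(lin.weight.grad)
> tensor([[0.2955]])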

That is great, @ptrblck!
As far as I can tell, the first approach will also work when, instead of nn.Linear, I have a whole model like VGG.

The second could become nasty, couldn’t it? I would need to go through the model's whole forward path and do it manually for the forward of each layer, wouldn't I?

Or could it be done only by:

x = input
model = vgg()
out = model(x)
out = model(out, model.parameters().detach())
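A call like the one above is not valid as written (model.parameters() returns a generator, which has no detach()). A minimal sketch of something close to it, assuming PyTorch 2.0 or newer where torch.func.functional_call is available, is shown below. The small shape-preserving nn.Sequential is a hypothetical stand-in for the real model, since the output has to be fed back in as input.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)
# Hypothetical stand-in model whose output can be fed back as input
model = nn.Sequential(nn.Linear(4, 4), nn.Tanh(), nn.Linear(4, 4))
x = torch.randn(2, 4)

out1 = model(x)  # first pass: parameters are tracked as usual

# Detached copies of all parameters, keyed by their module-qualified names
detached = {name: p.detach() for name, p in model.named_parameters()}

# Second pass: run the same module, but with the detached parameters
out2 = functional_call(model, detached, (out1,))

out2.sum().backward()
# Each p.grad now contains only the contribution of the first pass,
# weighted by the gradient that flowed back through the second pass.
```

This avoids rewriting the forward of each layer by hand, at the cost of depending on the torch.func API.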

About the for loop that disables requires_grad: for a whole model it could be time consuming. Is there a way to set requires_grad to False for all the parameters at once? (with torch.no_grad() would also turn off requires_grad on the result, which is undesirable.)
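For reference, nn.Module has a requires_grad_() method that sets the flag on every parameter of the module in one call. A minimal sketch, again with a small shape-preserving model as an assumed stand-in for the real network:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Assumed stand-in model; keeps the input shape so it can be applied twice
model = nn.Sequential(
    nn.Conv2d(3, 3, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(3, 3, kernel_size=3, padding=1),
)
x = torch.randn(1, 3, 8, 8)

out1 = model(x)               # first pass: parameters are tracked
model.requires_grad_(False)   # freeze all parameters at once
out2 = model(out1)            # second pass: parameters act as constants
model.requires_grad_(True)    # restore before backward() / optimizer.step()

# out2 is still part of the graph because it depends on out1
print(out2.requires_grad)     # True
out2.mean().backward()
```

Unlike torch.no_grad(), this keeps the second pass in the graph with respect to its input, so gradients still flow back to out1.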

I think the proposed solution doesn’t take one thing into account; correct me if I am wrong (although I’m not sure I want to take that one thing into account in the first place, and hopefully your answer will help me figure it out).

When I compute loss2.backward() in the regular case, I actually compute loss2.backward(1.0). Then, when the backward pass reaches x’, it arrives with the gradient from the second path (D_loss2 / D_x’), so what is effectively called is x’.backward(D_loss2 / D_x’).
I think that if I set requires_grad=False as you suggested, loss2.backward(1.0) will just backpropagate 1.0 through the second path, because of requires_grad=False, and then x’.backward will be called with 1.0 and not with the gradients of the second path.

I think the behavior I want is, on the one hand, that the second path does not directly affect my network's parameters, but on the other hand that its gradients are still passed back to x’ (like the use of backward(previous_gradients) with an LSTM). I hope this is explained clearly.

No, that shouldn’t be the case, and you could verify it either by calling backward() on the first output and comparing the gradients in lin, or by checking the intermediate gradient on the first output manually:

out = lin(x)
out_ = F.linear(out, lin.weight.detach())