Is it possible to mark part of the forward pass to only backpropagate the gradient but not to adjust weights?

In the following example code I have a Module that uses only one layer (one set of parameters) but it is used twice in the forward step. During the optimization I would expect the weights to be adjusted twice as well. If I want to only adjust the weights for one of the layer usages, what can I do?

import torch


class ExampleModel(torch.nn.Module):
    def __init__(self, dim) -> None:
        super(ExampleModel, self).__init__()
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, x):
        out1 = self.linear(x)     # backprop gradients and adjust weights here
        out2 = self.linear(out1)  # only backprop gradients here
        return out2


# Random input and output data for this example
N, D = 64, 100
x = torch.randn(N, D)
y = torch.randn(N, D)

model = ExampleModel(D)
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters())

y_pred = model(x)
loss = criterion(y_pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

The following will not work since with torch.no_grad() no gradient at all is backpropagated:

def forward(self, x):
    out1 = self.linear(x)  # backprop gradients and adjust weights here
    with torch.no_grad():
        out2 = self.linear(out1)  # only backprop gradients here
    return out2

I cannot simply exclude the parameters from the optimization, since they need to be optimized in the first part (i.e. out1 = self.linear(x)).
For the same reason I also cannot set a learning rate of 0 for these parameters.

What else can I do to achieve this?

I asked this question on Stack Overflow, but based on other questions I found during my Google search it seems that the PyTorch community is not very active on Stack Overflow but is active on this forum, so I am reposting my question here. See my Stack Overflow question here: https://stackoverflow.com/questions/62487201

What is the difference between torch.no_grad() and wrapping the line with .requires_grad_(False) -> .requires_grad_(True)?
The two seem equivalent to me.

With torch.no_grad() I get a runtime exception since there is no gradient:

def forward(self, x):
    out1 = self.linear(x)  # backprop gradients and adjust weights here
    with torch.no_grad():
        out2 = self.linear(out1)  # only backprop gradients here
    return out2

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

By wrapping it with .requires_grad_(False) -> .requires_grad_(True) it seems to run, but I don't understand why. To me that looks identical to the first solution.

def forward(self, x):
    out1 = self.linear(x)  # backprop gradients and adjust weights here
    self.linear.requires_grad_(False)
    out2 = self.linear(out1)  # only backprop gradients here
    self.linear.requires_grad_(True)
    return out2

I have the same confusion with .detach. If I detached it, how would the gradient be backpropagated at all?

Maybe the confusion is a matter of terminology, but I would expect no_grad, requires_grad_(False) and detach all to block gradient propagation. Maybe a feature request for a no_weight_adjustment / without_weight_adjustment parameter or context makes sense?

How should the detach solution work? Could you give an example or add a link to a source where something similar is described?

torch.no_grad() is more aggressive: Autograd won't track any operations inside the block, so out2 won't have a valid .grad_fn and you thus won't be able to call backward on the calculated loss. .requires_grad_(False) will disable the gradient calculation for these particular parameters, while the output of the operation will still have a .grad_fn if one of the inputs previously had a .grad_fn (out1 in your example).
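This distinction can be checked directly by inspecting .grad_fn and the accumulated gradients (a minimal sketch; the exact grad_fn class name may differ across PyTorch versions, so only its presence is checked):

```python
import torch

lin = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)

# Under no_grad, Autograd records nothing: the output has no grad_fn,
# so calling backward through it would fail.
with torch.no_grad():
    out_a = lin(x)
print(out_a.grad_fn)  # None

# With requires_grad_(False) the parameters are frozen, but the graph is
# still built through the input, so the output keeps a grad_fn and
# backward() can propagate to x.
lin.requires_grad_(False)
out_b = lin(x)
print(out_b.grad_fn is not None)  # True
out_b.sum().backward()
print(x.grad is not None)  # True: the gradient reached the input
print(lin.weight.grad)     # None: the frozen parameters received no gradient
lin.requires_grad_(True)
```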

Here is an example for all three workflows:

import torch
import torch.nn.functional as F


class ExampleModel(torch.nn.Module):
    def __init__(self, dim) -> None:
        super(ExampleModel, self).__init__()
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, x, mode):
        out1 = self.linear(x)
        if mode == 'original':
            out2 = self.linear(out1)
        elif mode == 'detach':
            # functional API with detached parameters
            out2 = F.linear(out1, self.linear.weight.detach(), self.linear.bias.detach())
        elif mode == 'req_grad':
            self.linear.requires_grad_(False)
            out2 = self.linear(out1)
            self.linear.requires_grad_(True)
        else:
            raise RuntimeError
        return out2


torch.manual_seed(2809)

# Random input and output data for this example
N, D = 64, 100
x = torch.randn(N, D)
y = torch.randn(N, D)

model = ExampleModel(D)
criterion = torch.nn.MSELoss(reduction='sum')

# Original
y_pred = model(x, mode='original')
loss = criterion(y_pred, y)
loss.backward()
weight_grad_ref = model.linear.weight.grad.clone()
bias_grad_ref = model.linear.bias.grad.clone()
model.zero_grad()

# Detach
y_pred = model(x, mode='detach')
loss = criterion(y_pred, y)
loss.backward()
weight_grad_detach = model.linear.weight.grad.clone()
bias_grad_detach = model.linear.bias.grad.clone()
model.zero_grad()

# Req_grad
y_pred = model(x, mode='req_grad')
loss = criterion(y_pred, y)
loss.backward()
weight_grad_req_grad = model.linear.weight.grad.clone()
bias_grad_req_grad = model.linear.bias.grad.clone()
model.zero_grad()

# Compare
print((weight_grad_ref - weight_grad_detach).abs().max())
print((weight_grad_detach - weight_grad_req_grad).abs().max())

As you can see, the "detach" and "requires_grad" approaches yield the same gradients.
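To confirm that the gradient still reaches the earlier part of the graph under the "detach" approach, one can inspect the input's gradient after backward (a minimal sketch, independent of the example above):

```python
import torch
import torch.nn.functional as F

lin = torch.nn.Linear(3, 3)
x = torch.randn(2, 3, requires_grad=True)

out1 = lin(x)
# Second usage with detached parameters: the weights receive a gradient
# only through the first usage, but the gradient still flows back to x.
out2 = F.linear(out1, lin.weight.detach(), lin.bias.detach())
out2.sum().backward()

print(x.grad is not None)           # True: backprop went through both layers
print(lin.weight.grad is not None)  # True: populated by the first usage only
```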

@ptrblck Thanks for the explanation and the extended example.

Okay, I see how the detach solution was meant here. Of course, in my real-world problem the model is not just a single linear layer, and with a detach solution I would need to rewrite my whole model with functional calls. I guess wrapping with requires_grad is easier, then. Is there a downside to using requires_grad compared to the functional (detach) approach?
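If the requires_grad toggling should be reused across a larger model, it could be wrapped in a small context manager. This is only a sketch: frozen is a hypothetical helper name, not a PyTorch API, and it restores each parameter's previous requires_grad flag on exit so nested or partially frozen modules are handled correctly:

```python
import torch
from contextlib import contextmanager


@contextmanager
def frozen(module):
    # Hypothetical helper: temporarily set requires_grad=False on all
    # parameters of `module`, restoring the previous flags afterwards.
    # Gradients still flow *through* the module to earlier layers.
    flags = [p.requires_grad for p in module.parameters()]
    module.requires_grad_(False)
    try:
        yield module
    finally:
        for p, flag in zip(module.parameters(), flags):
            p.requires_grad_(flag)


# Usage inside forward() would look like:
#     out1 = self.linear(x)
#     with frozen(self.linear):
#         out2 = self.linear(out1)

lin = torch.nn.Linear(3, 3)
x = torch.randn(2, 3, requires_grad=True)
out1 = lin(x)
with frozen(lin):
    out2 = lin(out1)
out2.sum().backward()
print(lin.weight.grad is not None)                      # gradient from the first usage only
print(all(p.requires_grad for p in lin.parameters()))   # flags restored
```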

I still think there is room for improvement in PyTorch here. The difference you mentioned between no_grad and requires_grad is not obvious, and I think a dedicated parameter or context for the "propagate the gradient but don't adjust the weights" use case would be helpful. It would make this much more explicit (I prefer explicitness).
What is the right place to propose such a feature? Do you see any chance for such a feature request to succeed?

You could create a feature request on GitHub and explain the shortcomings you are seeing as well as a proposal (if you have a good idea how it should look).

It depends on the request and if it would make all use cases more explicit or could yield ambiguous behavior in specific use cases.