Is it possible to mark part of the forward pass to only backpropagate the gradient but not to adjust weights?

In the following example code I have a Module that uses only one layer (one set of parameters) but it is used twice in the forward step. During the optimization I would expect the weights to be adjusted twice as well. If I want to only adjust the weights for one of the layer usages, what can I do?

import torch


class ExampleModel(torch.nn.Module):
    def __init__(self, dim) -> None:
        super(ExampleModel, self).__init__()
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, x):
        out1 = self.linear(x)     # backprop gradients and adjust weights here
        out2 = self.linear(out1)  # only backprop gradients here
        return out2


# Random input and output data for this example
N, D = 64, 100
x = torch.randn(N, D)
y = torch.randn(N, D)

model = ExampleModel(D)
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters())

y_pred = model(x)
loss = criterion(y_pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

The following will not work since with torch.no_grad() no gradient at all is backpropagated:

def forward(self, x):
    out1 = self.linear(x)  # backprop gradients and adjust weights here
    with torch.no_grad():
        out2 = self.linear(out1)  # only backprop gradients here
    return out2

I cannot simply exclude the parameters from the optimization, since they need to be optimized in the first part (i.e. out1 = self.linear(x)).
For the same reason I also cannot set a learning rate of 0 for these parameters.

What else can I do to achieve this?

I asked this question on Stack Overflow, but based on other questions I found during my Google search it seems that the PyTorch community is not very active on Stack Overflow but is active on this forum, so I am reposting my question here. See my Stack Overflow question here: https://stackoverflow.com/questions/62487201

What is the difference between torch.no_grad() and wrapping the line with .requires_grad_(False) -> .requires_grad_(True)?
The two seem equivalent to me.

With torch.no_grad() I get a runtime exception since there is no gradient:

def forward(self, x):
    out1 = self.linear(x)  # backprop gradients and adjust weights here
    with torch.no_grad():
        out2 = self.linear(out1)  # only backprop gradients here
    return out2

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

By wrapping it with .requires_grad_(False) -> .requires_grad_(True) it seems to run, but I don't understand why. To me that looks identical to the first solution.

def forward(self, x):
    out1 = self.linear(x)  # backprop gradients and adjust weights here
    self.linear.requires_grad_(False)
    out2 = self.linear(out1)  # only backprop gradients here
    self.linear.requires_grad_(True)
    return out2

I have the same confusion with .detach. If I detached it, how would the gradient be backpropagated at all?

Maybe the confusion is a matter of terminology, but I would expect no_grad, requires_grad_(False) and detach all to block gradient propagation. Maybe a feature request for a no_weight_adjustment / without_weight_adjustment parameter or context makes sense?

How should the detach solution work? Could you give an example or add a link to a source where something similar is described?

torch.no_grad() is more aggressive: Autograd won't track any operations inside the block, so out2 won't have a valid .grad_fn and you thus won't be able to call backward on the calculated loss. .requires_grad_(False) will disable the gradient calculation for these particular parameters, while the output of the operation will still have a .grad_fn if one of the inputs previously had a .grad_fn (out1 in your example).
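This distinction can be checked directly by inspecting .grad_fn and the accumulated gradients (a minimal sketch; the exact grad_fn class name may differ across PyTorch versions, so only its presence is checked):

```python
import torch

lin = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)

# Under no_grad, Autograd records nothing: the output has no grad_fn,
# so calling backward through it would fail.
with torch.no_grad():
    out_a = lin(x)
print(out_a.grad_fn)  # None

# With requires_grad_(False) the parameters are frozen, but the graph is
# still built through the input, so the output keeps a grad_fn and
# backward() can propagate to x.
lin.requires_grad_(False)
out_b = lin(x)
print(out_b.grad_fn is not None)  # True
out_b.sum().backward()
print(x.grad is not None)  # True: the gradient reached the input
print(lin.weight.grad)     # None: the frozen parameters received no gradient
lin.requires_grad_(True)
```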

Here is an example for all three workflows:

import torch
import torch.nn.functional as F


class ExampleModel(torch.nn.Module):
    def __init__(self, dim) -> None:
        super(ExampleModel, self).__init__()
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, x, mode):
        out1 = self.linear(x)
        if mode == 'original':
            out2 = self.linear(out1)
        elif mode == 'detach':
            # functional API with detached parameters
            out2 = F.linear(out1, self.linear.weight.detach(), self.linear.bias.detach())
        elif mode == 'req_grad':
            self.linear.requires_grad_(False)
            out2 = self.linear(out1)
            self.linear.requires_grad_(True)
        else:
            raise RuntimeError
        return out2


torch.manual_seed(2809)

# Random input and output data for this example
N, D = 64, 100
x = torch.randn(N, D)
y = torch.randn(N, D)

model = ExampleModel(D)
criterion = torch.nn.MSELoss(reduction='sum')

# Original
y_pred = model(x, mode='original')
loss = criterion(y_pred, y)
loss.backward()
weight_grad_ref = model.linear.weight.grad.clone()
bias_grad_ref = model.linear.bias.grad.clone()
model.zero_grad()

# Detach
y_pred = model(x, mode='detach')
loss = criterion(y_pred, y)
loss.backward()
weight_grad_detach = model.linear.weight.grad.clone()
bias_grad_detach = model.linear.bias.grad.clone()
model.zero_grad()

# Req_grad
y_pred = model(x, mode='req_grad')
loss = criterion(y_pred, y)
loss.backward()
weight_grad_req_grad = model.linear.weight.grad.clone()
bias_grad_req_grad = model.linear.bias.grad.clone()
model.zero_grad()

# Compare
print((weight_grad_ref - weight_grad_detach).abs().max())
print((weight_grad_detach - weight_grad_req_grad).abs().max())

As you can see, the "detach" and "requires_grad" approaches yield the same gradients.
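To confirm that the gradient still reaches the earlier part of the graph under the "detach" approach, one can inspect the input's gradient after backward (a minimal sketch, independent of the example above):

```python
import torch
import torch.nn.functional as F

lin = torch.nn.Linear(3, 3)
x = torch.randn(2, 3, requires_grad=True)

out1 = lin(x)
# Second usage with detached parameters: the weights receive a gradient
# only through the first usage, but the gradient still flows back to x.
out2 = F.linear(out1, lin.weight.detach(), lin.bias.detach())
out2.sum().backward()

print(x.grad is not None)           # True: backprop went through both layers
print(lin.weight.grad is not None)  # True: populated by the first usage only
```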

@ptrblck Thanks for the explanation and the extended example.

Okay, I see how the detach solution was meant here. Of course, in my real-world problem the model is not just a single linear layer, and with a detach solution I would need to rewrite my whole model with functional calls. I guess wrapping with requires_grad is easier, then. Is there a downside to using requires_grad compared to the functional (detach) approach?
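If the requires_grad toggling should be reused across a larger model, it could be wrapped in a small context manager. This is only a sketch: frozen is a hypothetical helper name, not a PyTorch API, and it restores each parameter's previous requires_grad flag on exit so nested or partially frozen modules are handled correctly:

```python
import torch
from contextlib import contextmanager


@contextmanager
def frozen(module):
    # Hypothetical helper: temporarily set requires_grad=False on all
    # parameters of `module`, restoring the previous flags afterwards.
    # Gradients still flow *through* the module to earlier layers.
    flags = [p.requires_grad for p in module.parameters()]
    module.requires_grad_(False)
    try:
        yield module
    finally:
        for p, flag in zip(module.parameters(), flags):
            p.requires_grad_(flag)


# Usage inside forward() would look like:
#     out1 = self.linear(x)
#     with frozen(self.linear):
#         out2 = self.linear(out1)

lin = torch.nn.Linear(3, 3)
x = torch.randn(2, 3, requires_grad=True)
out1 = lin(x)
with frozen(lin):
    out2 = lin(out1)
out2.sum().backward()
print(lin.weight.grad is not None)                      # gradient from the first usage only
print(all(p.requires_grad for p in lin.parameters()))   # flags restored
```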

I still think there is room for improvement in PyTorch here. The difference you mentioned between no_grad and requires_grad is not obvious, and I think a dedicated parameter or context for the "propagate the gradient but don't adjust the weights" use case would be helpful. It would make this much more explicit (I prefer explicitness).
What is the right place to propose such a feature? Do you see any chance for such a feature request to succeed?

You could create a feature request on GitHub and explain the shortcomings you are seeing as well as a proposal (if you have a good idea how it should look).

It depends on the request and if it would make all use cases more explicit or could yield ambiguous behavior in specific use cases.