How to use a layer with gradient but without weight adjustment?

Is it possible to mark part of the forward pass to only backpropagate the gradient but not to adjust weights?

In the following example code, I have a Module that uses only one layer (one set of parameters), but the layer is used twice in the forward pass. During optimization, I would expect the weights to be updated based on both usages. If I want the weights to be adjusted for only one of the two usages, what can I do?

import torch
    
class ExampleModel(torch.nn.Module):
    
    def __init__(self, dim) -> None:
        super(ExampleModel, self).__init__()
        self.linear = torch.nn.Linear(dim, dim)
    
    def forward(self, x):
        out1 = self.linear(x)  # backprop gradients and adjust weights here
        out2 = self.linear(out1)  # only backprop gradients here
        return out2
    
    
# Random input output data for this example
N, D = 64, 100
x = torch.randn(N, D)
y = torch.randn(N, D)
    
model = ExampleModel(D)
    
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters())
    
y_pred = model(x)
loss = criterion(y_pred, y)
    
optimizer.zero_grad()
loss.backward()
optimizer.step()

The following does not work, because inside a torch.no_grad() block no gradient is backpropagated at all:

def forward(self, x):
    out1 = self.linear(x)  # backprop gradients and adjust weights here
    with torch.no_grad():
        out2 = self.linear(out1)  # only backprop gradients here
    return out2

I cannot simply exclude the parameters from the optimization, since they need to be optimized in the first part (i.e. out1 = self.linear(x)).
For the same reason, I also cannot set a learning rate of 0 for these parameters.

What else can I do to achieve this?

I asked this question on Stack Overflow, but based on other questions I found during my Google search, it seems that the PyTorch community is not very active on Stack Overflow but is active on this forum. Therefore, I am reposting my question here. See my Stack Overflow question here: https://stackoverflow.com/questions/62487201

The answer posted on Stack Overflow seems to be valid.
The second approach, using detach(), might be easier with the functional API (torch.nn.functional).

@ptrblck That is confusing to me.

What is the difference between torch.no_grad() and wrapping the line with .requires_grad_(False) -> .requires_grad_(True)?
They seem equivalent to me.

With torch.no_grad() I get a runtime error, since out2 has no gradient history:

def forward(self, x):
    out1 = self.linear(x)  # backprop gradients and adjust weights here
    with torch.no_grad():
        out2 = self.linear(out1)  # only backprop gradients here
    return out2

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

By wrapping it with .requires_grad_(False) -> .requires_grad_(True), it runs without errors, but I don't understand why. To me it looks identical to the first solution.

def forward(self, x):
    out1 = self.linear(x)  # backprop gradients and adjust weights here
    self.linear.requires_grad_(False)
    out2 = self.linear(out1)  # only backprop gradients here
    self.linear.requires_grad_(True)
    return out2

I have the same confusion with .detach(). If I detach the tensor, how can the gradient still be backpropagated?

Maybe the confusion is about terminology, but I would expect no_grad, requires_grad_(False), and detach all to block gradient propagation. Maybe a feature request for a no_weight_adjustment / without_weight_adjustment parameter or context manager makes sense?

How should the detach solution work? Could you give an example or add a link to a source where something similar is described?

torch.no_grad() is more aggressive: Autograd won't track any operations inside the block, so out2 won't have a valid .grad_fn and you won't be able to call backward on the calculated loss.
.requires_grad_(False) disables the gradient calculation only for this particular parameter (or module), while the output of the operation will still have a .grad_fn if one of its inputs already had one (out1 in your example). The gradient therefore still flows through the second call back into out1, and from there to the parameters via their first usage.
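A minimal sketch of this difference, using a hypothetical small Linear layer and random input (sizes chosen only for illustration). It inspects .grad_fn of the second call's output under both approaches and checks that the weights still receive a gradient from the first call:

```python
import torch

torch.manual_seed(0)
linear = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)
out1 = linear(x)  # tracked: linear's parameters require grad, so out1 has a grad_fn

# torch.no_grad(): autograd records nothing inside the block
with torch.no_grad():
    out2_no_grad = linear(out1)

# requires_grad_(False): only the parameters are excluded; out1 is still tracked
linear.requires_grad_(False)
out2_req = linear(out1)
linear.requires_grad_(True)

print(out2_no_grad.grad_fn)            # None, so backward() from here would fail
print(out2_req.grad_fn is not None)    # True, so backward() works

# The gradient flows through the second call into out1 and reaches the
# weights only through the first call.
out2_req.sum().backward()
print(linear.weight.grad is not None)  # True
```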

Here is an example for all three workflows:

import torch
import torch.nn.functional as F

class ExampleModel(torch.nn.Module):
    
    def __init__(self, dim) -> None:
        super(ExampleModel, self).__init__()
        self.linear = torch.nn.Linear(dim, dim)
    
    def forward(self, x, mode):
        out1 = self.linear(x)  
        if mode == 'original':
            out2 = self.linear(out1)
        elif mode == 'detach':
            out2 = F.linear(out1, self.linear.weight.detach(), self.linear.bias.detach())
        elif mode == 'req_grad':
            self.linear.requires_grad_(False)
            out2 = self.linear(out1)
            self.linear.requires_grad_(True)
        else:
            raise RuntimeError(f"unknown mode: {mode}")
        return out2
    
torch.manual_seed(2809)
# Random input output data for this example
N, D = 64, 100
x = torch.randn(N, D)
y = torch.randn(N, D)
    
model = ExampleModel(D)
criterion = torch.nn.MSELoss(reduction='sum')
    
# Original
y_pred = model(x, mode='original')
loss = criterion(y_pred, y)
loss.backward()

weight_grad_ref = model.linear.weight.grad.clone()
bias_grad_ref = model.linear.bias.grad.clone()
model.zero_grad()

# Detach
y_pred = model(x, mode='detach')
loss = criterion(y_pred, y)
loss.backward()

weight_grad_detach = model.linear.weight.grad.clone()
bias_grad_detach = model.linear.bias.grad.clone()
model.zero_grad()

# Req_grad
y_pred = model(x, mode='req_grad')
loss = criterion(y_pred, y)
loss.backward()

weight_grad_req_grad = model.linear.weight.grad.clone()
bias_grad_req_grad = model.linear.bias.grad.clone()
model.zero_grad()

# Compare
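# Non-zero: the original version also accumulates gradients from the second usage
print((weight_grad_ref - weight_grad_detach).abs().max())
# Expected to be zero: "detach" and "req_grad" yield the same gradients
print((weight_grad_detach - weight_grad_req_grad).abs().max())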

As you can see, the “detach” and “requires_grad” approaches yield the same gradients.


@ptrblck Thanks for the explanation and the extended example.

Okay, I see how the detach solution was meant. Of course, in my real-world problem the model is not just a linear layer. I guess with the detach solution I would need to redefine my whole model with functional calls, so wrapping with requires_grad_ seems easier. Is there a downside to using requires_grad_ compared to the functional (detach) solution?
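One way to apply the detach approach to a whole sub-network without rewriting it functionally might be torch.func.functional_call, which runs a module with a substitute set of parameters. The sketch below assumes PyTorch 2.0 or newer; TwoStageModel and its block are hypothetical stand-ins for a real model. It passes detached copies of the parameters for the second usage and compares the resulting gradients with the requires_grad_ toggling approach:

```python
import torch
from torch.func import functional_call  # available in PyTorch >= 2.0


class TwoStageModel(torch.nn.Module):
    """Hypothetical model; `block` stands in for an arbitrary sub-network."""

    def __init__(self, dim):
        super().__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Linear(dim, dim), torch.nn.ReLU(), torch.nn.Linear(dim, dim)
        )

    def forward(self, x, mode="detach"):
        out1 = self.block(x)  # first usage: parameters are trained
        if mode == "detach":
            # Second usage: same module, but with detached copies of its parameters,
            # so gradients flow back into out1 without reaching the parameters here.
            detached = {name: p.detach() for name, p in self.block.named_parameters()}
            return functional_call(self.block, detached, (out1,))
        # Reference: the requires_grad_ toggling approach from this thread
        self.block.requires_grad_(False)
        out2 = self.block(out1)
        self.block.requires_grad_(True)
        return out2


torch.manual_seed(0)
model = TwoStageModel(8)
x = torch.randn(4, 8)

grads = {}
for mode in ("detach", "req_grad"):
    model.zero_grad()
    model(x, mode=mode).pow(2).sum().backward()
    grads[mode] = [p.grad.clone() for p in model.parameters()]

# Both approaches should produce the same parameter gradients
print(all(torch.allclose(a, b) for a, b in zip(grads["detach"], grads["req_grad"])))
```

Unlike toggling requires_grad_, this does not mutate the module's flags, so a forward pass that raises halfway cannot leave the parameters frozen.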

I still think PyTorch could be improved here. The difference you mentioned between no_grad and requires_grad_ is not obvious, and I think a dedicated parameter or context manager for the “propagate the gradient but don't adjust the weights” use case would be helpful. It would make this much more explicit (I prefer explicitness).
What is the right place to propose such a feature request? Do you see any chance for such a feature request to succeed?
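Until something like that exists, the requested behavior can be approximated with a small context manager built on requires_grad_. This is a sketch only: frozen_parameters is a hypothetical helper, not a PyTorch API. It remembers each parameter's flag and restores it in a finally block, so the flags are reset even if the forward pass raises:

```python
import contextlib

import torch


@contextlib.contextmanager
def frozen_parameters(module):
    """Hypothetical helper (not a PyTorch API): stop gradient computation for
    `module`'s parameters while gradients still flow through its outputs."""
    # Remember the original flags so partially frozen modules are restored correctly
    flags = [p.requires_grad for p in module.parameters()]
    try:
        module.requires_grad_(False)
        yield module
    finally:
        # Restore the flags even if the forward pass raises an exception
        for p, flag in zip(module.parameters(), flags):
            p.requires_grad_(flag)


class ExampleModel(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, x):
        out1 = self.linear(x)  # backprop gradients and adjust weights here
        with frozen_parameters(self.linear):
            out2 = self.linear(out1)  # only backprop gradients here
        return out2


model = ExampleModel(4)
model(torch.randn(2, 4)).sum().backward()
print(model.linear.weight.grad is not None)           # True
print(all(p.requires_grad for p in model.parameters()))  # True
```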

You could create a feature request on GitHub and explain the shortcomings you are seeing as well as a proposal (if you have a good idea how it should look).

It depends on the request: whether it would make all use cases more explicit, or whether it could yield ambiguous behavior in specific use cases.