How to customize the "backward" function to not calculate gradients and not update a selected sub-tensor of a parameter tensor of a layer

Let’s take the model below:

import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import torch

class NeuralNetwork(nn.Module):

    def __init__(self, input_size, output_size):
        super(NeuralNetwork, self).__init__()
        self.linear1 = nn.Linear(input_size, 10)
        self.linear2 = nn.Linear(10, output_size)

    def forward(self, x):
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x

model = NeuralNetwork(5, 2)
model.state_dict()["linear1.weight"].shape

# output :
# torch.Size([10, 5])

Run random data through the model and optimize the parameters (with just one forward and one backward pass):

input = torch.randn(200, 5)
target = torch.randint(0, 2, (200,))  # CrossEntropyLoss expects class indices as targets

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1, momentum=0.9)
output = model(input)
loss = criterion(output, target)

# update
loss.backward()
optimizer.step()

For example, for the first layer, I want to update only a sub-tensor of the weight tensor:

For a given idx, model.state_dict()["linear1.weight"][0:idx] should not be updated (and the same reasoning applies to the other layers, with a different idx).

I don’t want to replace the gradients with 0 after they have been calculated, because my goal is to avoid computing them in the first place, for performance reasons.

Any help would be appreciated.

In that case you could create a “fixed” and a “trainable” tensor in a custom linear layer and concatenate them in each forward pass. This would make sure that only the trainable part gets valid gradients and parameter updates.
This post gives you an example of such a layer and replaces it via torch.fx in another model.
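
For illustration, here is a minimal sketch of such a layer (not the exact code from the linked post; the split point frozen_rows and the initialization are placeholders):

import torch
import torch.nn as nn
import torch.nn.functional as F

class PartiallyFrozenLinear(nn.Module):
    """Linear layer whose first `frozen_rows` output rows are fixed (no gradients)."""

    def __init__(self, in_features, out_features, frozen_rows):
        super().__init__()
        weight = torch.randn(out_features, in_features) * 0.01
        bias = torch.zeros(out_features)
        # frozen part: registered as buffers, so autograd never computes gradients for it
        self.register_buffer("weight_frozen", weight[:frozen_rows].clone())
        self.register_buffer("bias_frozen", bias[:frozen_rows].clone())
        # trainable part: regular parameters
        self.weight_trainable = nn.Parameter(weight[frozen_rows:].clone())
        self.bias_trainable = nn.Parameter(bias[frozen_rows:].clone())

    def forward(self, x):
        # concatenate the fixed and trainable parts in every forward pass
        weight = torch.cat([self.weight_frozen, self.weight_trainable], dim=0)
        bias = torch.cat([self.bias_frozen, self.bias_trainable], dim=0)
        return F.linear(x, weight, bias)

Since the frozen buffers don’t require gradients, the backward of torch.cat only produces gradients for the trainable slices, and the optimizer only ever sees the trainable parameters.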

Thanks @ptrblck,

I took a look at that post, and from the results presented there, operations like torch.cat can consistently slow down the forward and backward passes.

Is it the only way to do it? Because if I understood correctly, we can’t do without the torch.cat calls.

Thanks in advance

No, it’s not the only way as you could also:

  • zero out the gradients (a small sketch of this is shown below), assuming your optimizer does not use internal running states to update the parameters, or
  • restore the parameters after a full update.

However, since you don’t want to compute the gradients for the frozen part of your parameters at all, I think the torch.cat approach is the only one meeting your requirements.
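
For completeness, a rough sketch of the first alternative (zeroing out the gradients), reusing model, criterion, and input from the snippet above; idx is a placeholder, and note that the full gradient is still computed here, only the update is masked:

idx = 3  # placeholder: number of rows of linear1.weight to keep fixed
target = torch.randint(0, 2, (200,))
optimizer = optim.SGD(model.parameters(), lr=1)  # plain SGD, no internal running states

optimizer.zero_grad()
output = model(input)
loss = criterion(output, target)
loss.backward()

# drop the gradients of the frozen rows before the update
with torch.no_grad():
    model.linear1.weight.grad[:idx].zero_()

optimizer.step()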


Thank you for the clarification @ptrblck,

Do you think it is feasible to improve the code from the post so that training the model with fixed tensors is faster than training the original model without fixed ones (since I want to use this for performance reasons)? It seems to me that the code is already optimal…
I know my question is a bit general, but I would like an expert opinion on whether it’s worth digging deeper into this.

Thanks

I think the fastest approach would depend on your actual use case, and I would suggest profiling the discussed methods.
In particular, using the torch.cat approach with the out argument could be beneficial if only a small part of the actual gradients should be calculated. On the other hand, you might not see huge benefits from avoiding the computation of the “frozen” gradients if the actual workload is tiny. Since you are concerned about optimal performance, a profile would be the right approach to see which method would add the most overhead.
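
For example, such a profile could be set up with torch.utils.benchmark. This is only a sketch using the NeuralNetwork definition above and the PartiallyFrozenLinear layer sketched earlier in the thread; the sizes, the number of frozen rows, and the iteration count are placeholders:

import copy
import torch
import torch.utils.benchmark as benchmark

def step(model, x):
    # one forward/backward pass (no optimizer step) to measure the autograd cost
    out = model(x)
    out.sum().backward()

x = torch.randn(200, 5)

model_original = NeuralNetwork(5, 2)
model_partial = copy.deepcopy(model_original)
# swap in the partially frozen layer sketched above (first 3 of the 10 rows frozen)
model_partial.linear1 = PartiallyFrozenLinear(5, 10, frozen_rows=3)

for label, m in [("original", model_original), ("partially frozen", model_partial)]:
    timer = benchmark.Timer(stmt="step(m, x)", globals={"step": step, "m": m, "x": x})
    print(label, timer.timeit(100))

With layers this small, the timings will mostly reflect framework overhead, which is exactly the point above about tiny workloads.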


Hi @ptrblck, I have a question:

Why is it necessary to use torch.fx here?

Is it possible to simply replace the selected layer with the customized one that freezes the desired parameters (using setattr(model, module_name, new_module)), without losing the weights in the transition?

You don’t need to use torch.fx, but it might be a convenient way to replace layers “automatically”.
Yes, you could also replace layers manually by replacing the internal attributes.
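
A rough sketch of such a manual replacement, copying the pretrained weights into the PartiallyFrozenLinear layer sketched earlier in the thread (idx is a placeholder):

idx = 3  # placeholder: number of rows of linear1.weight to freeze

old_layer = model.linear1
new_layer = PartiallyFrozenLinear(old_layer.in_features, old_layer.out_features, frozen_rows=idx)

with torch.no_grad():
    # copy the existing weights so nothing is lost in the transition
    new_layer.weight_frozen.copy_(old_layer.weight[:idx])
    new_layer.weight_trainable.copy_(old_layer.weight[idx:])
    new_layer.bias_frozen.copy_(old_layer.bias[:idx])
    new_layer.bias_trainable.copy_(old_layer.bias[idx:])

setattr(model, "linear1", new_layer)  # equivalent to: model.linear1 = new_layer

After the replacement, the optimizer should be re-created (or its param groups updated) so that it references the new trainable parameters.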
