import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

class NeuralNetwork(nn.Module):
    def __init__(self, input_size, output_size):
        super(NeuralNetwork, self).__init__()
        self.linear1 = nn.Linear(input_size, 10)
        self.linear2 = nn.Linear(10, output_size)

    def forward(self, x):
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x
In that case you could create a “fixed” and a “trainable” tensor in a custom linear layer and concatenate them in each forward pass. This would make sure that only the trainable part gets valid gradients and parameter updates. This post gives you an example of such a layer and replaces it via torch.fx in another model.
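For reference, a minimal sketch of such a layer could look like this (the class and attribute names are my own, not taken from the linked post): the frozen rows are stored as a buffer, so only the trainable slice receives gradients and optimizer updates.

import torch
import torch.nn as nn
import torch.nn.functional as F

class PartiallyFrozenLinear(nn.Module):
    def __init__(self, in_features, out_features, n_frozen):
        super().__init__()
        full_weight = torch.randn(out_features, in_features)
        # frozen rows are registered as a buffer: no gradient, no optimizer update
        self.register_buffer("weight_frozen", full_weight[:n_frozen])
        # the remaining rows stay trainable
        self.weight_trainable = nn.Parameter(full_weight[n_frozen:])
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        # concatenate the frozen and trainable parts in every forward pass
        weight = torch.cat([self.weight_frozen, self.weight_trainable], dim=0)
        return F.linear(x, weight, self.bias)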
I took a look at this post and, from the results presented there, operations like torch.cat can consistently slow down the forward and backward passes.
Is it the only way to do it? Because, if I understood correctly, we can’t do without the torch.cat calls.
The alternatives would be to either zero out the gradients (assuming your optimizer does not use internal running states to update the parameters) or restore the parameters after a full update.
However, since you don’t want to compute the gradients for the frozen part of your parameters at all, I think the torch.cat approach is the only one meeting your requirements.
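For completeness, the “zero out the gradients” alternative could look like this rough sketch, reusing the NeuralNetwork definition from above (the row indices are arbitrary examples); note that the full gradient is still computed here, which is exactly what you want to avoid.

model = NeuralNetwork(input_size=4, output_size=2)
optimizer = optim.SGD(model.parameters(), lr=0.1)  # plain SGD: no internal running states

out = model(torch.randn(8, 4))
out.mean().backward()

# erase the gradients of e.g. the first 5 rows of linear1.weight before stepping,
# so these rows keep their current values
model.linear1.weight.grad[:5] = 0.0
optimizer.step()
optimizer.zero_grad()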
Do you think it is feasible to improve the code of the post so that training the model with fixed tensors is faster than training the original model without fixed ones (since I want to use it for optimization reasons)? It seems to me that the code is already optimal…
I know my question is a bit general, but I would like an expert opinion on whether it’s worth digging deeper into this.
I think the fastest approach would depend on your actual use case, and I would suggest profiling the discussed methods.
In particular, using the torch.cat approach with the out argument could be beneficial if only a small part of the gradients needs to be calculated. On the other hand, you might not see huge benefits from avoiding the computation of the “frozen” gradients if the actual workload is tiny. Since you are concerned about optimal performance, a profile would be the right approach to see which method would add the most overhead.
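For instance, a quick comparison could be set up with torch.utils.benchmark (a rough sketch; PartiallyFrozenLinear refers to the hypothetical class sketched earlier, and the shapes are arbitrary):

import torch
import torch.utils.benchmark as benchmark

x = torch.randn(256, 32)
plain = torch.nn.Linear(32, 64)
partial = PartiallyFrozenLinear(32, 64, n_frozen=48)

# time a combined forward + backward pass for each variant
for name, module in [("plain", plain), ("partially frozen", partial)]:
    timer = benchmark.Timer(
        stmt="module(x).sum().backward()",
        globals={"module": module, "x": x},
    )
    print(name, timer.timeit(100))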
Is it possible to simply replace the selected layer with the customized one which freezes the desired parameters (using setattr(model, module_name, new_module)) without losing the weights in the transition?
You don’t need to use torch.fx, but it might be a convenient way to replace layers “automatically”.
Yes, you could also replace layers manually by replacing the internal attributes.
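A rough sketch of such a manual replacement, copying the pretrained weights into the new module (PartiallyFrozenLinear and its attributes are the hypothetical names from the sketch above):

model = NeuralNetwork(input_size=4, output_size=2)

old = model.linear1
new = PartiallyFrozenLinear(old.in_features, old.out_features, n_frozen=5)
with torch.no_grad():
    new.weight_frozen.copy_(old.weight[:5])     # first rows become frozen
    new.weight_trainable.copy_(old.weight[5:])  # remaining rows stay trainable
    new.bias.copy_(old.bias)

setattr(model, "linear1", new)  # manual replacement without torch.fx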