I want to freeze selected parameters of an existing PyTorch model. I used the torch.fx symbolic tracer to capture the model after its creation and to replace each layer containing parameters to be frozen with a custom layer:
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import torch.fx
from torch import optim
from tqdm import tqdm
class MyLinear(nn.Module):
    def __init__(self, bias, params, split_index):
        super().__init__()
        # rows [0:split_index] are frozen (stored as a buffer, no gradient),
        # the remaining rows stay trainable
        self.register_buffer("frozen_weight", params[0:split_index].detach().clone())
        self.weight = nn.Parameter(params[split_index:].detach().clone())
        # do the same with the bias if needed
        if bias is not None:
            self.bias = bias
        else:
            self.register_parameter("bias", None)

    def forward(self, x):
        weight = torch.cat((self.frozen_weight, self.weight), dim=0)
        out = F.linear(x, weight, self.bias)
        return out
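As a quick sanity check (just a sketch, reusing the imports and the MyLinear class above), the custom layer should reproduce the output of the nn.Linear it is built from:

# sketch: MyLinear built from an existing nn.Linear should give the same output
ref = nn.Linear(3, 3)
lin = MyLinear(ref.bias, ref.weight, 1)   # freeze the first output row
x = torch.randn(5, 3)
print(torch.allclose(lin(x), ref(x)))     # expected: True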
class NN(nn.Module):
    def __init__(self, input_size, output_size):
        super(NN, self).__init__()
        self.linear1 = nn.Linear(input_size, 3)
        self.linear2 = nn.Linear(3, 3)
        self.linear3 = nn.Linear(3, 2)
        self.linear4 = nn.Linear(2, output_size)

    def forward(self, x):
        x = self.linear1(x)
        x = nn.functional.relu(x)
        x = self.linear2(x)
        x = nn.functional.relu(x)
        x = self.linear3(x)
        x = nn.functional.relu(x)
        x = self.linear4(x)
        return x
def _parent_name(target):
    """
    Splits a qualname into parent path and last atom.
    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
    """
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name
def replace_node_module(node, modules, new_module):
    assert isinstance(node.target, str)
    parent_name, name = _parent_name(node.target)
    modules[node.target] = new_module
    setattr(modules[parent_name], name, new_module)
def matches_module_name(name, node):
    if not isinstance(node, torch.fx.Node):
        return False
    if node.op != 'call_module':
        return False
    if not isinstance(node.target, str):
        return False
    return name in node.target
def freeze(model, list_of_indexes):
    traced = torch.fx.symbolic_trace(model)
    modules = dict(traced.named_modules())
    name = "linear"
    idx = 0
    for node in traced.graph.nodes:
        if matches_module_name(name, node):
            replace_index = list_of_indexes[idx]
            idx += 1
            if replace_index is None:
                continue
            ref = modules[node.target]
            lin = MyLinear(ref.bias, ref.weight, replace_index)
            # swap the submodule in place; the call_module node keeps the same
            # target, so the graph itself does not need to change
            replace_node_module(node, modules, lin)
    new_traced = torch.fx.GraphModule(traced, traced.graph)
    return new_traced
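To double-check what freeze() produces, a sketch like the following (using the NN and the index list from the test below) inspects which submodules were swapped; the graph itself keeps the same call_module targets:

# sketch: inspect the result of freeze() on the example network
m = freeze(NN(2, 1), [1, 2, 1, None])
print(type(m.linear1).__name__)   # MyLinear (first row frozen)
print(type(m.linear4).__name__)   # Linear (index is None, left untouched)
print(m.graph)                    # call_module nodes still target linear1..linear4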
def forwardprop_and_backprop(model, data):
    # define loss function
    criterion = nn.CrossEntropyLoss()
    # define optimizer
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    optimizer.zero_grad()
    output = model(data)
    # random target, only used to drive a backward pass for timing
    loss = criterion(torch.squeeze(output), torch.randn(20).float())
    loss.backward()
    optimizer.step()
I then tested whether training the model with frozen weights (just one forward and one backward pass per iteration, repeated in a large for loop) saves computation time compared to training the original model with unfrozen weights:
data = torch.randn(20, 2)
list_of_indexes = [1, 2, 1, None]
original_model = NN(2, 1)
model_with_freeze = freeze(original_model, list_of_indexes)

# Test 1: original model
for epoch in tqdm(range(100000)):
    forwardprop_and_backprop(original_model, data)

# Test 2: model with frozen weights
for epoch in tqdm(range(100000)):
    forwardprop_and_backprop(model_with_freeze, data)
I assumed the network with frozen weights would be faster to train because there are fewer parameters to tune. However, the opposite is true: the frozen network takes 1 min 36 s while the unfrozen one takes 1 min 28 s. Why would this be?
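For what it's worth, a quick way to confirm that the frozen model really does expose fewer trainable parameters (a sketch, using the two models built above):

# sketch: compare trainable parameter counts of the two models
def count_trainable(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

print(count_trainable(original_model), count_trainable(model_with_freeze))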
Thanks