Thanks @ptrblck for your help!
I modified your code a bit so that it can change the frozen/trainable status of the model parameters over time (and so that MyLinear no longer uses random weights):
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import torch.fx
from torch import optim
from tqdm import tqdm
class MyLinear(nn.Module):
    def __init__(self, bias, params, split_index):
        super().__init__()
        # the first `split_index` rows are frozen (stored as a buffer, so they
        # receive no gradient); the remaining rows stay trainable
        self.register_buffer("frozen_weight", params[:split_index].detach().clone())
        self.weight = nn.Parameter(params[split_index:].detach().clone())
        # do the same with the bias if needed
        if bias is not None:
            self.bias = bias
        else:
            self.register_parameter("bias", None)

    def forward(self, x):
        # rebuild the full weight matrix from the frozen and trainable parts
        weight = torch.cat((self.frozen_weight, self.weight), dim=0)
        out = F.linear(x, weight, self.bias)
        return out
class NN(nn.Module):
    def __init__(self, input_size, output_size):
        super(NN, self).__init__()
        self.linear1 = nn.Linear(input_size, 3)
        self.linear2 = nn.Linear(3, 3)
        self.linear3 = nn.Linear(3, 2)
        self.linear4 = nn.Linear(2, output_size)

    def forward(self, x):
        x = self.linear1(x)
        x = nn.functional.relu(x)
        x = self.linear2(x)
        x = nn.functional.relu(x)
        x = self.linear3(x)
        x = nn.functional.relu(x)
        x = self.linear4(x)
        return x
def _parent_name(target):
    """
    Splits a qualname into parent path and last atom.
    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
    """
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name
def replace_node_module(node, modules, new_module):
    assert(isinstance(node.target, str))
    parent_name, name = _parent_name(node.target)
    modules[node.target] = new_module
    setattr(modules[parent_name], name, new_module)
def matches_module_name(name, node):
    if not isinstance(node, torch.fx.Node):
        return False
    if node.op != 'call_module':
        return False
    if not isinstance(node.target, str):
        return False
    return name in node.target
def freeze(model, list_of_indexes):
    traced = torch.fx.symbolic_trace(model)
    modules = dict(traced.named_modules())
    name = "linear"
    idx = 0
    for node in traced.graph.nodes:
        if matches_module_name(name, node):
            replace_index = list_of_indexes[idx]
            idx += 1
            if not replace_index:
                continue
            ref = modules[node.target]
            lin = MyLinear(ref.bias, ref.weight, replace_index)
            # swap the original nn.Linear for a MyLinear built from its weights;
            # the call_module node keeps its target, so the graph itself is unchanged
            replace_node_module(node, modules, lin)
    new_traced = torch.fx.GraphModule(traced, traced.graph)
    return new_traced
def forwardprop_and_backprop(model, data):
    # dummy regression loss, just to run a full training step
    criterion = nn.MSELoss()
    # define optimizer (recreated at every call to keep the example short)
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(torch.squeeze(output), torch.randn(20))
    loss.backward()
    optimizer.step()
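As a quick sanity check (a rough sketch, not a thorough test), this is the kind of thing that can be run with the definitions above to confirm that only the non-frozen rows of a replaced layer move after an optimizer step:

# sanity-check sketch: after one training step, the frozen rows of a replaced
# layer should be unchanged and the trainable rows should (normally) have moved
model = freeze(NN(2, 1), [1, None, None, None])
layer = model.linear1                      # replaced by MyLinear (split_index=1)
frozen_before = layer.frozen_weight.clone()
trainable_before = layer.weight.detach().clone()

forwardprop_and_backprop(model, torch.randn(20, 2))

print(torch.equal(layer.frozen_weight, frozen_before))       # expected: True
print(torch.equal(layer.weight.detach(), trainable_before))  # expected: False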
Let's take the code below as an example. To change which weights are frozen at each iteration, I need to unfreeze all of the frozen weights before calling freeze again with a new list_of_indexes. What is the most efficient way to write the unfreeze_all() function that does this? (I put a very rough sketch of what I have in mind below the snippet.)
import numpy as np

data = torch.randn(20, 2)
original_model = NN(2, 1)
for i in range(10):
    list_of_indexes = np.random.randint(2, size=4)
    model_with_freeze = freeze(original_model, list_of_indexes)
    forwardprop_and_backprop(model_with_freeze, data)
    # function to code
    unfreeze_all(model_with_freeze)
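For reference, this is the rough idea I have so far (untested sketch, and I am not even sure about the signature, since I also pass original_model): for every MyLinear, rebuild the full weight from the frozen buffer and the trainable part, then copy it back into the matching layer of the original model, so that the next call to freeze starts from the updated weights.

def unfreeze_all(model_with_freeze, original_model):
    # rough, untested sketch: copy the concatenated (frozen + trainable) weight
    # of every MyLinear back into the corresponding nn.Linear of original_model;
    # the bias is shared with the original model, so it needs no copying
    with torch.no_grad():
        for name, module in model_with_freeze.named_modules():
            if isinstance(module, MyLinear):
                full_weight = torch.cat((module.frozen_weight, module.weight), dim=0)
                original_model.get_submodule(name).weight.copy_(full_weight)

Is there a cheaper or more idiomatic way to do this, e.g. without copying the weights back at every iteration?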
I have another question. I tested whether training the partially frozen model in a large for loop saves computation time compared to training the original model:
data = torch.randn(20, 2)
list_of_indexes = [1, 2, 1, None]
original_model = NN(2, 1)
model_with_freeze = freeze(original_model, list_of_indexes)

# Test 1
for epoch in tqdm(range(100000)):
    forwardprop_and_backprop(original_model, data)

# Test 2
for epoch in tqdm(range(100000)):
    forwardprop_and_backprop(model_with_freeze, data)
Strangely, Test 2 (01:36) takes more time than Test 1 (01:28), even though Test 2 should skip some gradient computations. Is this normal?
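My guess is that the extra torch.cat in MyLinear.forward costs more than what is saved by skipping gradients for such small layers, but I have not verified it. This is the kind of micro-benchmark I was thinking of to isolate the forward passes (rough sketch, CPU timing only):

import time

def time_forward(model, data, n_iters=100000):
    # time only the forward pass, to see whether the extra torch.cat in
    # MyLinear.forward accounts for the difference (rough CPU timing)
    with torch.no_grad():
        start = time.perf_counter()
        for _ in range(n_iters):
            model(data)
    return time.perf_counter() - start

print("original:", time_forward(original_model, data))
print("frozen:  ", time_forward(model_with_freeze, data))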
Thanks