You could try to use torch.fx
to trace the model and replace the modules with your custom (partially frozen) layers.
I’ve used this tutorial to replace the linear layers in this code snippet:
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
class MyLinear(nn.Module):
    """Linear layer whose weight is split into a trainable and a frozen part.

    The first ``split_index`` output rows are stored as an ``nn.Parameter``;
    the remaining ``out_features - split_index`` rows are stored as a buffer
    named ``frozen_weight``, so they are not returned by ``parameters()``.
    Both parts are freshly initialised with ``torch.randn``.

    Args:
        in_features: Size of each input sample.
        out_features: Total number of output features (trainable + frozen).
        bias: Bias to use as-is (e.g. the ``bias`` of an existing layer), or
            ``None`` for no bias. Note this is the bias itself, not a flag.
        split_index: Number of leading output rows that stay trainable.
    """

    def __init__(self, in_features, out_features, bias, split_index):
        super().__init__()
        frozen_rows = out_features - split_index
        # Trainable rows first, frozen rows second: forward() concatenates
        # them in this order.
        self.weight = nn.Parameter(torch.randn(split_index, in_features))
        self.register_buffer(
            "frozen_weight", torch.randn(frozen_rows, in_features)
        )
        # do the same with the bias if needed
        if bias is None:
            self.register_parameter("bias", None)
        else:
            self.bias = bias

    def forward(self, x):
        """Apply the linear map using the re-assembled full weight matrix."""
        full_weight = torch.cat((self.weight, self.frozen_weight), dim=0)
        return F.linear(x, full_weight, self.bias)
class NN(nn.Module):
    """Small MLP: four linear layers with ReLU after the first three.

    Layer widths are ``input_size -> 3 -> 3 -> 2 -> output_size``.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, 3)
        self.linear2 = nn.Linear(3, 3)
        self.linear3 = nn.Linear(3, 2)
        self.linear4 = nn.Linear(2, output_size)

    def forward(self, x, eq_indexes_list):
        """Run the forward pass.

        ``eq_indexes_list`` is accepted but not used by this method.
        """
        hidden = F.relu(self.linear1(x))
        hidden = F.relu(self.linear2(hidden))
        hidden = F.relu(self.linear3(hidden))
        return self.linear4(hidden)
def _parent_name(target):
"""
Splits a qualname into parent path and last atom.
For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
"""
*parent, name = target.rsplit('.', 1)
return parent[0] if parent else '', name
def replace_node_module(node, modules, new_module):
    """Swap the module that ``node`` calls for ``new_module``.

    Updates both the ``modules`` lookup dict (qualified name -> module) and
    the attribute on the parent module found in that dict, so the node's
    target name now resolves to ``new_module``.
    """
    target = node.target
    assert isinstance(target, str)
    parent_path, attr_name = _parent_name(target)
    modules[target] = new_module
    # The root module is stored under the empty name '' in `modules`.
    setattr(modules[parent_path], attr_name, new_module)
def matches_module_name(name, node):
    """Return True if ``node`` is a ``call_module`` fx node whose string
    target contains ``name`` as a substring; otherwise return False."""
    if not isinstance(node, torch.fx.Node) or node.op != 'call_module':
        return False
    target = node.target
    return isinstance(target, str) and name in target
# Trace the model, then swap selected linear layers for partially frozen ones.
model = NN(3, 1)
traced = torch.fx.symbolic_trace(model)
modules = dict(traced.named_modules())
name = "linear"
idx = 0
# The graph topology is not modified below: only the submodules behind the
# existing `call_module` targets are swapped. The copy just gives the new
# GraphModule its own graph object.
new_graph = copy.deepcopy(traced.graph)
# One entry per matching layer, in graph order:
# None -> keep the original layer, int -> number of trainable output rows.
list_of_indexes = [None, 2, 1, None]
for node in traced.graph.nodes:
    if not matches_module_name(name, node):
        continue
    print("match for ", node)
    replace_index = list_of_indexes[idx]
    idx += 1
    # Compare against None explicitly so that a split index of 0
    # (fully frozen weight) is not mistaken for "skip".
    if replace_index is None:
        print("skipping since replace_index is empty")
        continue
    print("replace_index: ", replace_index)
    ref = modules[node.target]
    lin = MyLinear(ref.in_features, ref.out_features, ref.bias, replace_index)
    # Re-binding the attribute on the parent module is all that is needed:
    # the node keeps its target name, which now resolves to `lin`.
    # Do NOT call node.replace_all_uses_with(lin) (it expects an fx.Node, not
    # a module) or erase the node (it is still used, and it belongs to
    # traced.graph rather than new_graph).
    replace_node_module(node, modules, lin)
new_traced = torch.fx.GraphModule(traced, new_graph)
print(new_traced)
# GraphModule(
# (linear1): Linear(in_features=3, out_features=3, bias=True)
# (linear2): MyLinear()
# (linear3): MyLinear()
# (linear4): Linear(in_features=2, out_features=1, bias=True)
# )
Note that I changed `list_of_indexes`
a bit, as this version fits better into my code example.
You can of course extend the solution if needed.