Customize the backpropagation phase of a neural network to skip updating some parameters

You could try to use torch.fx to trace the model and replace the modules with your custom (partially frozen) layers.
I adapted the module-replacement approach from a torch.fx tutorial (the link was not preserved here) to replace the linear layers in this code snippet:

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

class MyLinear(nn.Module):
    """Linear layer whose weight rows are split into a trainable and a frozen part.

    The first ``split_index`` output rows live in ``self.weight`` (an
    ``nn.Parameter``, so optimizers update it). The remaining
    ``out_features - split_index`` rows live in the ``frozen_weight`` buffer,
    which is stored in the ``state_dict`` but receives no gradient.

    Args:
        in_features: size of each input sample.
        out_features: total number of output rows (trainable + frozen).
        bias: an existing bias tensor/Parameter to reuse (e.g. ``ref.bias`` of
            the layer being replaced), or ``None`` for no bias.
        split_index: number of leading output rows that stay trainable.

    Raises:
        ValueError: if ``split_index`` is not in ``[0, out_features]``.
    """

    def __init__(self, in_features, out_features, bias, split_index):
        # nn.Module.__init__ must run before any Parameter/buffer is assigned.
        super().__init__()
        if not 0 <= split_index <= out_features:
            raise ValueError(
                f"split_index must be in [0, {out_features}], got {split_index}"
            )
        # Both parts start from random values; a caller replacing an existing
        # layer can copy that layer's weights in afterwards.
        self.weight = nn.Parameter(torch.randn(split_index, in_features))
        self.register_buffer(
            "frozen_weight", torch.randn(out_features - split_index, in_features)
        )
        # The bias is reused as-is (and stays trainable if it is a Parameter);
        # split/freeze it the same way as the weight if needed.
        if bias is not None:
            self.bias = bias
        else:
            # Still define the attribute so forward() can pass self.bias.
            self.register_parameter("bias", None)

    def forward(self, x):
        # Reassemble the full (out_features, in_features) weight, trainable rows first.
        weight = torch.cat((self.weight, self.frozen_weight), dim=0)
        return F.linear(x, weight, self.bias)

class NN(nn.Module):
    """Small 4-layer MLP used as the tracing example.

    Layer sizes: ``input_size -> 3 -> 3 -> 2 -> output_size``, with a ReLU
    after each of the first three linear layers and none after the last.
    """

    def __init__(self, input_size, output_size):
        super(NN, self).__init__()
        self.linear1 = nn.Linear(input_size, 3)
        self.linear2 = nn.Linear(3, 3)
        self.linear3 = nn.Linear(3, 2)
        self.linear4 = nn.Linear(2, output_size)

    def forward(self, x, eq_indexes_list=None):
        """Run the MLP on ``x``.

        ``eq_indexes_list`` is accepted for call compatibility but is not used
        by this forward pass; it defaults to ``None`` so ``model(x)`` works.
        """
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x))
        return self.linear4(x)

def _parent_name(target):
    Splits a qualname into parent path and last atom.
    For example, `` -> (``, `baz`)
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name

def replace_node_module(node, modules, new_module):
    """Swap the submodule called by ``node`` for ``new_module``.

    Updates both the ``modules`` lookup dict (keyed by qualified name, as
    produced by ``dict(root.named_modules())``) and the attribute on the
    parent module. Later lookups through ``modules`` and attribute access on
    the parent therefore both see the new module.

    Raises:
        TypeError: if ``node.target`` is not a string module path.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(node.target, str):
        raise TypeError(
            f"expected a string module target, got {type(node.target).__name__}"
        )
    parent_name, name = _parent_name(node.target)
    modules[node.target] = new_module
    setattr(modules[parent_name], name, new_module)

def matches_module_name(name, node):
    """Return True if ``node`` calls a submodule whose qualified name contains ``name``.

    Returns False for inputs that are not ``torch.fx.Node``, for nodes whose
    op is not ``'call_module'``, and for nodes with a non-string target.
    This is a substring test, so ``"linear"`` matches ``linear1`` through
    ``linear4`` (and would also match e.g. ``block.linear``).
    """
    if not isinstance(node, torch.fx.Node):
        return False
    if node.op != 'call_module':
        return False
    if not isinstance(node.target, str):
        return False
    return name in node.target

model = NN(3, 1)
traced = torch.fx.symbolic_trace(model)

# Qualified name -> submodule; the root (traced) module is stored under ''.
modules = dict(traced.named_modules())
name = "linear"
idx = 0

# Separate graph copy for the new GraphModule so it does not share `traced.graph`.
new_graph = copy.deepcopy(traced.graph)
# One entry per matched layer, in graph order: None keeps the original
# nn.Linear; an int k makes the first k output rows trainable and freezes the rest.
list_of_indexes = [None, 2, 1, None]

for node in traced.graph.nodes:
    if not matches_module_name(name, node):
        continue
    print("match for ", node)
    replace_index = list_of_indexes[idx]
    # Consume this layer's entry exactly once, whether or not it is replaced.
    idx += 1
    # `is None` rather than falsiness, so 0 (freeze every row) is still honored.
    if replace_index is None:
        print("skipping since replace_index is empty")
        continue
    print("replace_index: ", replace_index)
    ref = modules[node.target]
    lin = MyLinear(ref.in_features, ref.out_features, ref.bias, replace_index)
    # Carry over the existing weights so the swap does not re-initialize the layer.
    with torch.no_grad():
        lin.weight.copy_(ref.weight[:replace_index])
        lin.frozen_weight.copy_(ref.weight[replace_index:])
    replace_node_module(node, modules, lin)

# Submodules are taken from `traced`, which now holds the MyLinear replacements.
new_traced = torch.fx.GraphModule(traced, new_graph)
# GraphModule(
#   (linear1): Linear(in_features=3, out_features=3, bias=True)
#   (linear2): MyLinear()
#   (linear3): MyLinear()
#   (linear4): Linear(in_features=2, out_features=1, bias=True)
# )

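Note that I changed `list_of_indexes` slightly so that it fits this code example better. It holds one entry per `linear*` layer, in graph order. `None` keeps the original `nn.Linear`. An integer `k` replaces the layer with a `MyLinear` whose first `k` output rows are trainable and whose remaining rows are stored as a frozen buffer.
You can of course extend the solution if needed (e.g. to freeze part of the bias as well).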