Customize the backpropagation phase of a neural network to skip updating some parameters

During the backpropagation step, I want to update only the parameters of the selected neurons and freeze the parameters of the other neurons, given that the two groups are spatially separated within each layer.

Is it possible to split the weight tensor of each layer into two tensors and set `requires_grad` to False on the one containing the red parameters (see image below)?
If not, how else can this be done?

Capture

Yes, creating different tensors (trainable and frozen ones) should work. Here is a small example showing this use case:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyLinear(nn.Module):
    """Linear layer whose last ``num_frozen`` output rows are frozen.

    The weight matrix is stored in two pieces:

    * ``weight`` -- an ``nn.Parameter`` holding the first
      ``out_features - num_frozen`` rows. These receive gradients and are
      returned by ``parameters()``, so an optimizer updates them.
    * ``frozen_weight`` -- a registered buffer holding the remaining
      ``num_frozen`` rows. It does not require gradients, so an optimizer never
      updates it, but it is still part of ``state_dict()``.

    The two pieces are concatenated on every forward pass.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: if True, add a trainable bias of shape ``(out_features,)``.
            The bias is not split, so all of it stays trainable.
        num_frozen: number of output rows (neurons) to freeze. Defaults to 2,
            the value that used to be hard-coded.

    Raises:
        ValueError: if ``num_frozen`` is not in ``[0, out_features]``.
    """

    def __init__(self, in_features, out_features, bias=True, num_frozen=2):
        super().__init__()
        if not 0 <= num_frozen <= out_features:
            raise ValueError(
                f"num_frozen must be in [0, {out_features}], got {num_frozen}"
            )
        self.weight = nn.Parameter(
            torch.randn(out_features - num_frozen, in_features)
        )
        # A buffer (not a Parameter): saved with the module, never optimized.
        self.register_buffer("frozen_weight", torch.randn(num_frozen, in_features))

        # The bias is left unsplit; split it the same way if it must be frozen too.
        if bias:
            self.bias = nn.Parameter(torch.randn(out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, x):
        # Trainable rows first, frozen rows last -- same order as in __init__.
        weight = torch.cat((self.weight, self.frozen_weight), dim=0)
        return F.linear(x, weight, self.bias)
    

# Build a 10 -> 5 layer: 3 trainable output rows plus 2 frozen rows.
my_linear = MyLinear(10, 5)
x = torch.randn(8, 10)
my_out = my_linear(x)

my_out.mean().backward()
# The trainable rows receive a gradient of shape (3, 10) ...
print(my_linear.weight.grad)
# ... while the frozen rows live in a buffer created without requires_grad,
# so this prints None.
print(my_linear.frozen_weight.grad)

# The state_dict holds "weight", "bias" and the "frozen_weight" buffer, so the
# frozen rows are still saved and restored with the module.
print(my_linear.state_dict())

I’ve used a simple split, but you can of course use more complex masking, etc. Also, note that I haven’t split the bias, so you might want to do that as well.

@ptrblck Thanks for your response,

What if I want to modify an already existing model?

Here is the code of the model shown in the image. How can I define the split_neurones() function, which takes as input a list of lists (one inner list per layer) containing the indexes of the red neurons, and performs the split?

import torch
import torch.nn as nn

class NN(nn.Module):
    """Small MLP: input_size -> 3 -> 3 -> 2 -> output_size.

    ReLU follows each of the first three linear layers; the last layer's
    output is returned as is. The attribute names ``linear1``..``linear4``
    are what torch.fx-based module replacement matches on.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, 3)
        self.linear2 = nn.Linear(3, 3)
        self.linear3 = nn.Linear(3, 2)
        self.linear4 = nn.Linear(2, output_size)

    def forward(self, x, eq_indexes_list=None):
        """Run the MLP on ``x``.

        Args:
            x: input tensor whose last dimension is ``input_size``.
            eq_indexes_list: unused. Optional (default None) so that plain
                ``model(x)`` works while callers passing it still work.

        Returns:
            Tensor whose last dimension is ``output_size``.
        """
        x = nn.functional.relu(self.linear1(x))
        x = nn.functional.relu(self.linear2(x))
        x = nn.functional.relu(self.linear3(x))
        return self.linear4(x)

# The same model as in the attached picture.
model=NN(3,1)

# Red-neuron indexes per layer: one inner list per linear layer
# (linear1..linear4). Per the question, red neurons are the ones whose
# parameters should be frozen.
list_of_indexes=[[],[0,1],[0],[]]

def split_neurones(list_of_indexes):
   """Placeholder for splitting each layer's weights into trainable/frozen parts.

   Not implemented: the body is ``pass``, so calling it returns None and
   leaves ``model`` unchanged.
   """
   pass

You could try to use torch.fx to trace the model and replace the modules with your custom (partially frozen) layers.
I’ve used this tutorial to replace the linear layers in this code snippet:

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy


class MyLinear(nn.Module):
    """Linear layer whose first ``split_index`` output rows are trainable.

    Rows ``[:split_index]`` of the weight are the ``weight`` parameter. Rows
    ``[split_index:]`` are the ``frozen_weight`` buffer, which does not
    require gradients, so an optimizer never updates them. The two parts are
    concatenated on every forward pass.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: an existing bias ``nn.Parameter`` to reuse (shared, not copied,
            and fully trainable), or None for no bias.
        split_index: number of trainable output rows.
        weight: optional existing ``(out_features, in_features)`` weight to
            initialise from; its values are copied and detached from the
            source. When omitted, both parts are drawn from ``torch.randn``,
            which discards the weights of the layer being replaced.

    Raises:
        ValueError: if ``split_index`` is outside ``[0, out_features]`` or
            ``weight`` has the wrong shape.
    """

    def __init__(self, in_features, out_features, bias, split_index, weight=None):
        super().__init__()
        if not 0 <= split_index <= out_features:
            raise ValueError(
                f"split_index must be in [0, {out_features}], got {split_index}"
            )
        if weight is None:
            trainable = torch.randn(split_index, in_features)
            frozen = torch.randn(out_features - split_index, in_features)
        else:
            if tuple(weight.shape) != (out_features, in_features):
                raise ValueError(
                    f"weight must have shape {(out_features, in_features)}, "
                    f"got {tuple(weight.shape)}"
                )
            # Copy so the new layer neither aliases nor backprops into the old one.
            trainable = weight[:split_index].detach().clone()
            frozen = weight[split_index:].detach().clone()
        self.weight = nn.Parameter(trainable)
        self.register_buffer("frozen_weight", frozen)

        # The bias is reused whole; split it as well if it must be frozen too.
        if bias is not None:
            self.bias = bias
        else:
            self.register_parameter("bias", None)

    def forward(self, x):
        # Trainable rows first, frozen rows last -- same order as in __init__.
        weight = torch.cat((self.weight, self.frozen_weight), dim=0)
        return F.linear(x, weight, self.bias)


class NN(nn.Module):
    """Small MLP: input_size -> 3 -> 3 -> 2 -> output_size.

    ReLU follows each of the first three linear layers; the last layer's
    output is returned as is. The attribute names ``linear1``..``linear4``
    are what the torch.fx replacement loop matches on.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, 3)
        self.linear2 = nn.Linear(3, 3)
        self.linear3 = nn.Linear(3, 2)
        self.linear4 = nn.Linear(2, output_size)

    def forward(self, x, eq_indexes_list=None):
        """Run the MLP on ``x``.

        Args:
            x: input tensor whose last dimension is ``input_size``.
            eq_indexes_list: unused. Optional (default None) so that plain
                ``model(x)`` -- and the traced module -- can be called with
                a single input, while callers passing it still work.

        Returns:
            Tensor whose last dimension is ``output_size``.
        """
        x = nn.functional.relu(self.linear1(x))
        x = nn.functional.relu(self.linear2(x))
        x = nn.functional.relu(self.linear3(x))
        return self.linear4(x)


def _parent_name(target):
    """
    Splits a qualname into parent path and last atom.
    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
    """
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name


def replace_node_module(node, modules, new_module):
    """Install ``new_module`` as the submodule that ``node`` calls.

    Both the ``modules`` lookup dict (qualified name -> module) and the
    attribute on the owning parent module are updated. The graph itself is
    not touched: ``node`` keeps its target name, which now resolves to
    ``new_module``.
    """
    assert isinstance(node.target, str)
    owner_path, _dot, attr_name = node.target.rpartition('.')
    modules[node.target] = new_module
    setattr(modules[owner_path], attr_name, new_module)


def matches_module_name(name, node):
    """Return True if ``node`` calls a submodule whose qualified name contains ``name``.

    Only ``torch.fx.Node`` objects with ``op == 'call_module'`` and a string
    target can match; anything else yields False. This is a substring test,
    so ``"linear"`` also matches e.g. ``"block.linear_out"``.
    """
    is_module_call = (
        isinstance(node, torch.fx.Node)
        and node.op == 'call_module'
        and isinstance(node.target, str)
    )
    return is_module_call and name in node.target


# Trace the model: each submodule call becomes a `call_module` node whose
# target is the submodule's qualified name (e.g. "linear2").
model = NN(3, 1)
traced = torch.fx.symbolic_trace(model)

modules = dict(traced.named_modules())
name = "linear"

# One entry per matching layer, in graph order: the number of trainable
# output rows for MyLinear, or a falsy value (None/0) to keep the layer.
list_of_indexes = [None, 2, 1, None]

# Collect the matches first rather than iterating the graph while editing it.
linear_nodes = [node for node in traced.graph.nodes if matches_module_name(name, node)]
for idx, node in enumerate(linear_nodes):
    print("match for ", node)

    replace_index = list_of_indexes[idx]
    if not replace_index:
        print("skipping since replace_index is empty")
        continue
    print("replace_index: ", replace_index)

    ref = modules[node.target]
    lin = MyLinear(ref.in_features, ref.out_features, ref.bias, replace_index)
    # Swapping the submodule is all that is needed: the call_module node keeps
    # its target name, which now resolves to `lin`. Do NOT edit the graph here:
    # node.replace_all_uses_with() expects a Node (not a Module), and
    # erase_node() must be called on the graph that owns the node.
    replace_node_module(node, modules, lin)

# The graph itself is unchanged; a GraphModule built from a copy of it copies
# each call_module target from `traced`, which now holds the MyLinear layers.
new_graph = copy.deepcopy(traced.graph)
new_traced = torch.fx.GraphModule(traced, new_graph)
print(new_traced)
# GraphModule(
#   (linear1): Linear(in_features=3, out_features=3, bias=True)
#   (linear2): MyLinear()
#   (linear3): MyLinear()
#   (linear4): Linear(in_features=2, out_features=1, bias=True)
# )

Note that I changed the list_of_indexes a bit as this version fits better into my code example.
You can of course expand the solution a bit if needed.

Thanks @ptrblck for your help!

I modified your code a bit so that it changes the model parameters’ status (frozen or not) over time (no random weights in MyLinear):

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import torch.fx
from torch import optim 
from tqdm import tqdm

class MyLinear(nn.Module):
    """Linear layer built from an existing weight, with its leading rows frozen.

    Rows ``[:split_index]`` of ``params`` become the ``frozen_weight`` buffer.
    Rows ``[split_index:]`` become the trainable ``weight`` parameter. Neither
    part is copied: both share storage with ``params``, so optimizer updates
    to ``weight`` are written into the source layer's weight tensor.

    Args:
        bias: the source layer's bias parameter (shared, fully trainable),
            or None for no bias.
        params: the source layer's ``(out_features, in_features)`` weight.
        split_index: number of leading output rows to freeze.
    """

    def __init__(self, bias, params, split_index):
        super().__init__()

        # detach() is essential: slicing a Parameter yields a view that still
        # requires grad and is attached to `params` in the autograd graph.
        # Without it, backward() keeps computing gradients for the "frozen"
        # rows and accumulates them into params.grad.
        self.register_buffer("frozen_weight", params[:split_index].detach())
        # New leaf Parameter sharing storage with the trainable rows of `params`.
        self.weight = nn.Parameter(params[split_index:].detach())

        # The bias is reused whole; split it too if it must be frozen.
        if bias is not None:
            self.bias = bias
        else:
            self.register_parameter("bias", None)

    def forward(self, x):
        # Frozen rows first, trainable rows last -- same order as in __init__.
        weight = torch.cat((self.frozen_weight, self.weight), dim=0)
        return F.linear(x, weight, self.bias)


class NN(nn.Module):
    """Four-layer MLP (input_size -> 3 -> 3 -> 2 -> output_size).

    Each of the first three linear layers is followed by ReLU; the output of
    ``linear4`` is returned without an activation.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # Attribute names contain "linear" so they can be matched by name.
        self.linear1 = nn.Linear(input_size, 3)
        self.linear2 = nn.Linear(3, 3)
        self.linear3 = nn.Linear(3, 2)
        self.linear4 = nn.Linear(2, output_size)

    def forward(self, x):
        hidden = x
        for layer in (self.linear1, self.linear2, self.linear3):
            hidden = nn.functional.relu(layer(hidden))
        return self.linear4(hidden)

def _parent_name(target):
    """
    Splits a qualname into parent path and last atom.
    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
    """
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name


def replace_node_module(node, modules, new_module):
    """Make ``new_module`` the submodule called by the fx ``node``.

    Records it in ``modules`` under ``node.target`` and assigns it on the
    parent module found at the target's dotted prefix. The graph is left
    unchanged.
    """
    target = node.target
    assert isinstance(target, str)
    *path_parts, attr_name = target.split('.')
    modules[target] = new_module
    setattr(modules['.'.join(path_parts)], attr_name, new_module)

def matches_module_name(name, node):
    """Tell whether fx ``node`` is a module call whose target contains ``name``.

    Non-nodes, non-``call_module`` nodes and non-string targets never match.
    The check is a plain substring test on the qualified target name.
    """
    if not isinstance(node, torch.fx.Node) or node.op != 'call_module':
        return False
    target = node.target
    return isinstance(target, str) and name in target


def freeze(model, list_of_indexes):
    """Return a traced copy of ``model`` with selected linear layers partially frozen.

    ``model`` is traced with ``torch.fx``. Every ``call_module`` node whose
    target contains ``"linear"`` is paired, in graph order, with the next
    entry of ``list_of_indexes``:

    * a falsy entry (None or 0) keeps that layer;
    * any other value ``k`` replaces it with
      ``MyLinear(layer.bias, layer.weight, k)``. That is a layer built from
      the original's own weight and bias tensors, whose first ``k`` output
      rows are frozen.

    The replacements are installed only on the traced copy; ``model``'s
    submodules are never reassigned. So there is no structural state to
    "unfreeze": call ``freeze(model, other_indexes)`` again to get a module
    with a different frozen pattern.

    Args:
        model: an ``nn.Module`` that ``torch.fx.symbolic_trace`` can trace.
        list_of_indexes: one entry per matching linear layer (list, tuple or
            1-D numpy array).

    Returns:
        A new ``torch.fx.GraphModule``.

    Raises:
        IndexError: if ``list_of_indexes`` has fewer entries than there are
            matching layers.
    """
    traced = torch.fx.symbolic_trace(model)
    modules = dict(traced.named_modules())
    name = "linear"

    # Collect the matches first rather than iterating the graph while editing.
    linear_nodes = [n for n in traced.graph.nodes if matches_module_name(name, n)]
    for idx, node in enumerate(linear_nodes):
        replace_index = list_of_indexes[idx]
        if not replace_index:
            continue
        ref = modules[node.target]
        lin = MyLinear(ref.bias, ref.weight, replace_index)
        # Only the submodule is swapped; the call_module node keeps its target
        # name, which now resolves to `lin`. The graph must not be edited:
        # replace_all_uses_with() expects a Node (not a Module), and
        # erase_node() must be called on the graph that owns the node.
        replace_node_module(node, modules, lin)

    # The graph is unchanged; building from a copy of it picks up the swapped
    # submodules from `traced`.
    return torch.fx.GraphModule(traced, copy.deepcopy(traced.graph))
def forwardprop_and_backprop(model, data, target=None, optimizer=None):
    """Run one training step (forward, MSE loss, backward, optimizer step).

    Args:
        model: module mapping ``data`` to outputs of shape ``(N, 1)`` or ``(N,)``.
        data: input batch; its first dimension is the batch size ``N``.
        target: optional regression target of shape ``(N,)``. Defaults to
            fresh standard-normal noise sized to the batch (previously
            hard-coded to 20 samples).
        optimizer: optional optimizer to reuse across calls. When omitted, a
            new ``SGD(lr=0.1, momentum=0.9)`` is created on every call, so its
            momentum buffer starts empty each step and momentum has no
            effect. Pass one in to get real momentum.

    Returns:
        The loss value as a Python float.
    """
    # MSE, not CrossEntropyLoss: the model emits one value per sample.
    # CrossEntropyLoss on the squeezed (N,) output treats it as a single
    # unbatched sample with N class scores, i.e. it softmaxes across the batch.
    criterion = nn.MSELoss()
    if optimizer is None:
        optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    optimizer.zero_grad()

    output = model(data)
    # Drop only the trailing size-1 feature dim so a batch of 1 keeps its
    # batch dimension (a bare squeeze() would produce a scalar).
    if output.dim() > 1 and output.shape[-1] == 1:
        output = output.squeeze(-1)
    if target is None:
        target = torch.randn(output.shape[0])
    loss = criterion(output, target.float())
    loss.backward()
    optimizer.step()
    return loss.item()

Let’s take the code below as an example.
Here I need to unfreeze all frozen weights before calling the “freeze” function again with a new “list_of_indexes”. What is the most efficient way to write the unfreeze_all() function to do this?

import numpy as np

data=torch.randn(20,2)
original_model=NN(2,1)

for i in range(10):
  # One random 0/1 per linear layer: 0 keeps the layer as is, 1 replaces it
  # with a MyLinear whose first output row is frozen.
  list_of_indexes=np.random.randint(2, size=4)
  model_with_freeze=freeze(original_model,list_of_indexes)
  forwardprop_and_backprop(model_with_freeze,data)

  #function to code
  # NOTE(review): unfreeze_all is not defined anywhere in this file, so this
  # line raises NameError as written. freeze() only swaps submodules on the
  # traced copy it returns (original_model's submodules are not reassigned),
  # so calling freeze(original_model, ...) again on the next iteration may
  # already be all that is needed -- confirm before writing unfreeze_all.
  unfreeze_all(model_with_freeze)

I have another question: I tested whether training the frozen model in a large “for” loop saves computation time compared to training the original model.

data=torch.randn(20,2)
# Freeze 1, 2 and 1 leading rows of linear1..linear3; keep linear4 unchanged.
list_of_indexes=[1, 2, 1, None]
original_model=NN(2,1)
model_with_freeze=freeze(original_model,list_of_indexes)
#Test1
# NOTE(review): the two runs are not independent. model_with_freeze reuses
# original_model's linear4 module and the bias parameters of linear1..linear3,
# and its MyLinear layers are built from original_model's weight tensors, so
# Test2 starts from parameters already trained by Test1.
for epoch in tqdm(range(100000)):
  forwardprop_and_backprop(original_model,data)
#Test2
# NOTE(review): every MyLinear forward pass concatenates its frozen and
# trainable weight pieces (torch.cat), an extra op per layer per step. With
# layers this small, such per-op overhead can outweigh the gradient work
# saved, so a slower Test2 is plausible -- profile to confirm.
for epoch in tqdm(range(100000)):
  forwardprop_and_backprop(model_with_freeze,data)

Strangely, Test2 (01:36) takes more time than Test1 (01:28), even though Test2 should avoid some gradient calculations. Is this normal?

Thanks