Why does training a model with frozen weights take longer than training the model with unfrozen weights?

I want to freeze selected parameters of an existing PyTorch model. I used the torch.fx symbolic tracer to capture the model after its creation and replaced each layer that contains parameters to be frozen with a custom layer:

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import torch.fx
from torch import optim 
from tqdm import tqdm

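class MyLinear(nn.Module):
    def __init__(self, bias, params, split_index):
        super().__init__()
        # frozen rows: a buffer, so they get no gradient and no optimizer updates
        self.register_buffer("frozen_weight", params[:split_index].detach().clone())
        # trainable rows: a new leaf parameter
        self.weight = nn.Parameter(params[split_index:].detach().clone())
        # do the same with the bias if needed
        if bias is not None:
            self.bias = bias
        else:
            self.register_parameter("bias", None)

    def forward(self, x):
        # rebuild the full weight matrix on every forward pass
        weight = torch.cat((self.frozen_weight, self.weight), dim=0)
        out = F.linear(x, weight, self.bias)
        return out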

class NN(nn.Module):
    def __init__(self,input_size,output_size):
        super(NN, self).__init__()
        self.linear1=nn.Linear(input_size, 3)
        self.linear2=nn.Linear(3, 3)
        self.linear3=nn.Linear(3, 2)
        self.linear4=nn.Linear(2, output_size)

    def forward(self, x):
        x = self.linear1(x)
        x = nn.functional.relu(x)
        x = self.linear2(x)
        x = nn.functional.relu(x)
        x = self.linear3(x)
        x = nn.functional.relu(x)
        x = self.linear4(x)
        return x

def _parent_name(target):
    """
    Splits a qualname into parent path and last atom.
    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
    """
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name

def replace_node_module(node, modules, new_module):
    assert(isinstance(node.target, str))
    parent_name, name = _parent_name(node.target)
    modules[node.target] = new_module
    setattr(modules[parent_name], name, new_module)

def matches_module_name(name, node):
    if not isinstance(node, torch.fx.Node):
        return False
    if node.op != 'call_module':
        return False
    if not isinstance(node.target, str):
        return False
    if name in node.target:
        return True
    return False

def freeze(model,list_of_indexes):
  traced = torch.fx.symbolic_trace(model)
  modules = dict(traced.named_modules())
  name = "linear"
  idx = 0

  new_graph = copy.deepcopy(traced.graph)

  for node in traced.graph.nodes:
      if matches_module_name(name, node):
          replace_index = list_of_indexes[idx]
          idx += 1
          if not replace_index:
              continue  # None: keep this layer fully trainable
          ref = modules[node.target]
          lin = MyLinear(ref.bias, ref.weight, replace_index)
          replace_node_module(node, modules, lin)

  new_traced = torch.fx.GraphModule(traced, new_graph)
  return new_traced

def forwardprop_and_backprop(model, data):
  # define loss function
  criterion = nn.CrossEntropyLoss()

  # define optimizer
  optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

  output = model(data)
  loss = criterion(torch.squeeze(output), torch.randn(20).float())
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
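To check that a layer of this kind really keeps the frozen rows fixed, here is a small self-contained sketch. It re-declares MyLinear with `super().__init__()` and a `frozen_weight` buffer (an assumption about what the original code intended), and, unlike the code above, copies the bias so it does not share storage with the source `nn.Linear`. It compares the output with the original layer, takes one SGD step, and counts trainable parameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyLinear(nn.Module):
    """Linear layer whose first `split_index` output rows are frozen (sketch)."""

    def __init__(self, bias, params, split_index):
        super().__init__()
        # frozen rows: a buffer, so the optimizer never sees them
        self.register_buffer("frozen_weight", params[:split_index].detach().clone())
        # remaining rows: a trainable leaf parameter
        self.weight = nn.Parameter(params[split_index:].detach().clone())
        # the bias is copied so it does not share storage with the source layer
        if bias is not None:
            self.bias = nn.Parameter(bias.detach().clone())
        else:
            self.register_parameter("bias", None)

    def forward(self, x):
        # the full weight is rebuilt with torch.cat on every call
        weight = torch.cat((self.frozen_weight, self.weight), dim=0)
        return F.linear(x, weight, self.bias)


torch.manual_seed(0)
ref = nn.Linear(3, 3)
lin = MyLinear(ref.bias, ref.weight, split_index=1)

x = torch.randn(20, 3)
# before any update, the split layer matches the original one
same_output = torch.allclose(lin(x), ref(x))

frozen_before = lin.frozen_weight.clone()
optimizer = torch.optim.SGD(lin.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
lin(x).pow(2).sum().backward()
optimizer.step()

# the buffer is untouched; only weight (2x3) and bias (3) are trainable
frozen_unchanged = torch.equal(lin.frozen_weight, frozen_before)
n_trainable = sum(p.numel() for p in lin.parameters())
print(same_output, frozen_unchanged, n_trainable)  # prints: True True 9
```

Keeping the frozen rows in a buffer (rather than a parameter with `requires_grad=False`) also keeps them out of `model.parameters()`, so the optimizer is never handed them.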

I tested whether training the model with frozen weights (just one forward and backward pass per iteration, repeated in a large “for” loop) saves computation time compared to training the original model (with unfrozen weights):

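list_of_indexes = [1, 2, 1, None]
frozen_model = freeze(model, list_of_indexes)  # model/data setup not shown
for epoch in tqdm(range(100000)):  # frozen weights
    forwardprop_and_backprop(frozen_model, data)
for epoch in tqdm(range(100000)):  # unfrozen weights
    forwardprop_and_backprop(model, data)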

I assumed the network with frozen weights would be faster to train because there are fewer parameters to tune. However, the opposite happens: the frozen network takes 1 min 36 s while the unfrozen one takes 1 min 28 s. Why would this be?
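Reading the elapsed time from tqdm also includes tqdm's own per-iteration bookkeeping. A minimal sketch of a more direct measurement, assuming a CPU run, times one training step with `time.perf_counter` after a few warm-up calls, and builds the optimizer once outside the timed loop (unlike `forwardprop_and_backprop`, which creates it on every call). The two-layer model, `MSELoss`, and tensor sizes are placeholders, not taken from the post; on a GPU you would also need `torch.cuda.synchronize()` before reading the clock.

```python
import time

import torch
import torch.nn as nn


def time_steps(step_fn, n_steps=1000, n_warmup=50):
    """Average wall-clock seconds per call of step_fn (sketch)."""
    for _ in range(n_warmup):  # warm-up calls are not timed
        step_fn()
    start = time.perf_counter()
    for _ in range(n_steps):
        step_fn()
    return (time.perf_counter() - start) / n_steps


# hypothetical stand-in model and data, not the ones from the post
model = nn.Sequential(nn.Linear(10, 3), nn.ReLU(), nn.Linear(3, 1))
data = torch.randn(20, 10)
target = torch.randn(20)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
criterion = nn.MSELoss()


def train_step():
    # one forward + backward + update, optimizer reused across calls
    optimizer.zero_grad()
    loss = criterion(model(data).squeeze(1), target)
    loss.backward()
    optimizer.step()


per_step = time_steps(train_step, n_steps=200)
print(f"{per_step * 1e6:.1f} us per step")
```

Calling `time_steps` once per model variant with the same `n_steps` gives per-step numbers that can be compared directly.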


Training a model with frozen weights may take longer than training a model with unfrozen weights because the model with frozen weights cannot improve its performance on the training data by adjusting the weights of the frozen layers. This means that the model has to rely solely on the weights of the unfrozen layers to improve its performance, which can be a slower process.

This doesn’t mean that computing the gradient always takes more time when some weights are frozen. I think the time required to compute the gradient depends on the complexity of the model and the size of the training data, as well as on the choice of optimizer and learning rate. This is just my inference, though.

There must be a mathematical explanation if your inference is correct.
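One rough way to make this precise is to count multiply-adds for a single linear layer. This sketch assumes a batch of $B$ inputs, a weight $W$ with $m$ rows and $n$ columns, and the first $k$ rows frozen, as in MyLinear:

$$
Y = X W^\top + \mathbf{1} b^\top, \qquad X \in \mathbb{R}^{B \times n},\; W \in \mathbb{R}^{m \times n}
$$

$$
\text{forward: } Bmn, \qquad
\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} W \;(Bmn), \qquad
\frac{\partial L}{\partial W} = \Big(\frac{\partial L}{\partial Y}\Big)^{\!\top} X \;(Bmn)
$$

$$
W = \begin{bmatrix} W_f \\ W_t \end{bmatrix},\; W_f \in \mathbb{R}^{k \times n}
\;\Longrightarrow\;
\frac{\partial L}{\partial W_t} = \Big(\frac{\partial L}{\partial Y_{:,\,k+1:m}}\Big)^{\!\top} X \;\big(B(m-k)n\big)
$$

So freezing $k$ rows could at best save $Bkn$ multiply-adds per layer; the forward pass and $\partial L/\partial X$ still use the full $W$. Because MyLinear passes the concatenated $W$ to `F.linear`, and that tensor requires grad as a whole, autograd computes the full $\partial L/\partial W$ and the backward of `torch.cat` only slices out the trainable rows, so even that saving is likely not realized; mainly the optimizer update shrinks from $mn$ to $(m-k)n$ elements. For `linear2` in the question ($m = n = 3$, $k = 2$, and $B = 20$ if the `torch.randn(20)` target is the batch), the whole layer is a few hundred multiply-adds, so per-op overhead, including the extra `torch.cat`, probably dominates the step time.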

Based on your code snippet, I don’t think you are only comparing the forward (and backward) calls of a plain linear model vs. a partially frozen one; you are also comparing how the model returned by symbolic_trace behaves after its linear layers are replaced with the custom ones, which add extra torch.cat calls.
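A minimal micro benchmark for the `torch.cat` part, using `torch.utils.benchmark.Timer`: it times forward plus backward of a single `F.linear` call with an ordinary 3x3 weight against the same call with the weight rebuilt from a frozen 2x3 block and a trainable 1x3 block. The shapes mirror `linear2` with split index 2; the batch of 20 and the random data are assumptions. Absolute numbers depend on the machine and thread settings, so only the relative difference is meaningful.

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark

torch.manual_seed(0)
x = torch.randn(20, 3)
full_w = torch.randn(3, 3, requires_grad=True)
frozen_rows = torch.randn(2, 3)  # like the frozen buffer: no grad
trainable_rows = torch.randn(1, 3, requires_grad=True)
bias = torch.randn(3, requires_grad=True)


def plain_step():
    # forward + backward through an ordinary linear op
    F.linear(x, full_w, bias).sum().backward()


def cat_step():
    # same shapes, but the weight is rebuilt with torch.cat first
    w = torch.cat((frozen_rows, trainable_rows), dim=0)
    F.linear(x, w, bias).sum().backward()


results = {}
for label, fn in [("plain", plain_step), ("cat", cat_step)]:
    timer = benchmark.Timer(stmt="fn()", globals={"fn": fn})
    results[label] = timer.timeit(2000).mean  # mean seconds per call
    print(f"{label}: {results[label] * 1e6:.1f} us")
```

Gradients simply accumulate in `.grad` across calls here, which does not matter for timing.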

Again, you are also not profiling symbolic_trace standalone, but together with the additional torch.cat calls.
Write micro benchmarks for each of these parts and narrow down where the slowdown is coming from before trying to optimize the wrong part.
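A similar sketch for the torch.fx part: it traces the unmodified `NN` from the question with `torch.fx.symbolic_trace` (no MyLinear replacement, so no `torch.cat`) and times forward plus backward of the eager and traced versions side by side. `input_size=10`, `output_size=1`, and the batch of 20 are assumptions. If the two timings are close, tracing by itself is unlikely to explain the gap, which would point back at the extra ops in the custom layer.

```python
import torch
import torch.fx
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import benchmark


class NN(nn.Module):
    # same architecture as in the question
    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, 3)
        self.linear2 = nn.Linear(3, 3)
        self.linear3 = nn.Linear(3, 2)
        self.linear4 = nn.Linear(2, output_size)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x))
        return self.linear4(x)


torch.manual_seed(0)
eager = NN(input_size=10, output_size=1)  # sizes are assumptions
traced = torch.fx.symbolic_trace(eager)  # same submodules, nothing replaced
data = torch.randn(20, 10)

# the traced GraphModule should compute the same outputs
same = torch.allclose(eager(data), traced(data))


def step(model):
    # forward + backward, no optimizer, to isolate graph execution cost
    model(data).sum().backward()


timings = {}
for label, model in [("eager", eager), ("traced", traced)]:
    timer = benchmark.Timer(stmt="step(model)", globals={"step": step, "model": model})
    timings[label] = timer.timeit(2000).mean  # mean seconds per call
    print(f"{label}: {timings[label] * 1e6:.1f} us")
```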