Will freezing an intermediate block influence the gradient flow?

The problem I’m facing: I want to insert a small pre-trained model into an existing model to do something like feature enhancement. What I want to know is whether the freezing operation (setting the requires_grad flag of its parameters to False) will influence the gradient calculation, especially for the layers before the inserted block.

def __init__(self, block, num_blocks, num_classes=10):
    ...
    ce = MGN_FIN()
   
    # load CE parameters
    ce = torch.nn.DataParallel(ce)
    ckp = torch.load('...')
    ce.load_state_dict(ckp['model'])

    # freeze CE layer's weights
    for p in ce.module.parameters():
        p.requires_grad = False

    # construct CE module
    self.CE = nn.Sequential(
        ce.module.block1,
        ce.module.block2,
        ce.module.block3,
    )  

...
def forward(self, x):
    out = F.relu(self.bn1(self.conv1(x)))
    out = self.CE(out)

    out = self.layer1(out)
    out = self.CE(out)

    out = self.layer2(out)
    out = self.layer3(out)
    out = F.avg_pool2d(out, out.size()[3])
    out = out.view(out.size(0), -1)
    out = self.linear(out)
    return out

Any help is appreciated!

A cleaner solution would be:

out = F.relu(self.bn1(self.conv1(x)))
with torch.no_grad():
    out = self.CE(out)

However, I see that you use the CE block several times, so this solution is not feasible: no gradient flows out of a torch.no_grad() region, so the layers feeding each CE call (conv1/bn1 as well as self.layer1) would never be trained. Instead, I suggest you either:

  1. manually set the .grad attribute of every parameter in self.CE.parameters() to None after loss.backward() and before calling optimizer.step(), or

  2. restrict the set of parameters passed to the optimizer; pseudo-code:

ce_params = {id(p) for p in your_model.CE.parameters()}
optimizer = optim.SGD(
    [p for p in your_model.parameters() if id(p) not in ce_params],
    lr=learning_rate)

Either of these avoids the problem described in your question; see the sketch below.
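A minimal sketch of both options on a toy model (a hypothetical stand-in, not your actual network; `frozen` plays the role of self.CE):

import torch
import torch.nn as nn
import torch.optim as optim

# Toy model: `frozen` stands in for the pre-trained CE block.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 2))
frozen = model[1]

# Option 2: hand the optimizer only the parameters outside the frozen block.
frozen_ids = {id(p) for p in frozen.parameters()}
optimizer = optim.SGD([p for p in model.parameters() if id(p) not in frozen_ids],
                      lr=0.1)

x, target = torch.randn(8, 4), torch.randn(8, 2)
loss = nn.functional.mse_loss(model(x), target)
loss.backward()

# Option 1 (alternative): clear any gradients accumulated on the frozen block
# after backward() and before step(), so step() never updates them.
for p in frozen.parameters():
    p.grad = None

optimizer.step()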

I am not sure whether setting requires_grad to False interferes with the gradient calculation; maybe do a simple experiment with two linear layers yourself?

About the chain rule: gradients of a transformation with respect to its input are not ignored by freezing. Any module can be viewed as a chainable transformation f(x, p) (p being its local parameters), and freezing only means ignoring the gradients of p, i.e. computing f(x, p.detach()).
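Spelled out with the chain rule (a generic sketch, not tied to any code in this thread): for a three-layer stack

\[
y = f_3(f_2(f_1(x, p_1), p_2), p_3), \qquad
\frac{\partial y}{\partial p_1}
  = \frac{\partial f_3}{\partial f_2}\,
    \frac{\partial f_2}{\partial f_1}\,
    \frac{\partial f_1}{\partial p_1},
\]

freezing p_2 only drops the (now unused) term \partial y / \partial p_2; the factor \partial f_2 / \partial f_1, the gradient of f_2 with respect to its input, is still computed, so p_1 keeps receiving gradients.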

Thanks for your help!

I’ve done the following simple experiment:


import torch
import torch.nn as nn
import torch.optim as optim


class Demo(nn.Module):

    def __init__(self):
        super(Demo, self).__init__()
        self.layer1 = nn.Linear(1, 1, bias=False)
        self.layer2 = nn.Linear(1, 1, bias=False)
        self.layer3 = nn.Linear(1, 1, bias=False)

        self.layer1.weight = nn.Parameter(torch.ones([1, 1]).float())
        self.layer2.weight = nn.Parameter(torch.ones([1, 1]).float() * 2)
        self.layer3.weight = nn.Parameter(torch.ones([1, 1]).float() * 3)

        # self.layer1.register_forward_hook(lambda _, x_in, x_out: print('layer1', x_in, x_out))
        # self.layer2.register_forward_hook(lambda _, x_in, x_out: print('layer2', x_in, x_out))
        # self.layer3.register_forward_hook(lambda _, x_in, x_out: print('layer3', x_in, x_out))

        self.layer1.register_backward_hook(lambda _, grad_in, grad_out: print('layer1', grad_in, grad_out))
        self.layer2.register_backward_hook(lambda _, grad_in, grad_out: print('layer2', grad_in, grad_out))
        self.layer3.register_backward_hook(lambda _, grad_in, grad_out: print('layer3', grad_in, grad_out))

        # self.layer2.weight.requires_grad = False

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)

        return out


if __name__ == '__main__':
    model = Demo()
    optimizer = optim.SGD(params=model.parameters(), lr=0.1)
    optimizer.zero_grad()

    print(model.layer1.weight, model.layer2.weight, model.layer3.weight)
    # for i in range(10):
    x = torch.ones((1, 1))
    out = model(x)
    out.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(model.layer1.weight, model.layer2.weight, model.layer3.weight)

The result is:

# weights of linear layers
tensor([[1.]], requires_grad=True) 
tensor([[2.]], requires_grad=True) 
tensor([[3.]], requires_grad=True)

# for a linear layer: y = Wx
# gradients of backward_hook  ((gradient for W, gradient for x), gradient for y)
layer3 (tensor([[3.]]), tensor([[2.]])) (tensor([[1.]]),)
layer2 (tensor([[6.]]), tensor([[3.]])) (tensor([[3.]]),)
layer1 (None, tensor([[6.]])) (tensor([[6.]]),)

# updated weights of linear layers 
tensor([[0.4000]], requires_grad=True)
tensor([[1.7000]], requires_grad=True) 
tensor([[2.8000]], requires_grad=True)

If we set self.layer2.weight.requires_grad = False, we’ll get:

tensor([[1.]], requires_grad=True)
tensor([[2.]]) 
tensor([[3.]], requires_grad=True)

layer3 (tensor([[3.]]), tensor([[2.]])) (tensor([[1.]]),)
layer2 (tensor([[6.]]), None) (tensor([[3.]]),)
layer1 (None, tensor([[6.]])) (tensor([[6.]]),)

tensor([[0.4000]], requires_grad=True) 
tensor([[2.]])
tensor([[2.8000]], requires_grad=True)

We can see from this that setting requires_grad to False only makes the gradient with respect to that layer's weight None; the gradient with respect to the input x is still calculated and passed on to the shallower layers, so they can be updated by gradient descent correctly.

Good experiment. It shows that with requires_grad=False autograd still backpropagates through the frozen layer (layer2) to its input, but does not populate a gradient for layer2.weight, so gradients still flow to layer1 while layer2 itself is not affected.

There is one tiny error in your comment:

-# gradients of backward_hook  ((gradient for W, gradient for x), gradient for y)
+# gradients of backward_hook  ((gradient for x, gradient for W), gradient for y)
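Worth noting: on recent PyTorch versions register_backward_hook is deprecated in favor of register_full_backward_hook, whose grad_input contains only the gradients with respect to the module's inputs, so this ordering ambiguity does not arise. A minimal sketch, assuming a reasonably current PyTorch:

import torch
import torch.nn as nn

layer = nn.Linear(1, 1, bias=False)

# grad_in holds only the gradient w.r.t. the layer's input, never the weight.
layer.register_full_backward_hook(
    lambda module, grad_in, grad_out: print('grad wrt input:', grad_in))

x = torch.ones(1, 1, requires_grad=True)
layer(x).backward()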

And @googlebot, detach() is definitely inappropriate here, since it detaches the whole remaining chain from the backward computation completely, so layer1 will not receive any gradients:

import torch.nn as nn
import torch.optim as optim
import torch

class Demo(nn.Module):

    def __init__(self):
        super(Demo, self).__init__()
        self.layer1 = nn.Linear(1, 1, bias=False)
        self.layer2 = nn.Linear(1, 1, bias=False)
        self.layer3 = nn.Linear(1, 1, bias=False)

        self.layer1.weight = nn.Parameter(torch.ones([1, 1]).float())
        self.layer2.weight = nn.Parameter(torch.ones([1, 1]).float() * 2)
        self.layer3.weight = nn.Parameter(torch.ones([1, 1]).float() * 3)

        # self.layer1.register_forward_hook(lambda _, x_in, x_out: print('layer1', x_in, x_out))
        # self.layer2.register_forward_hook(lambda _, x_in, x_out: print('layer2', x_in, x_out))
        # self.layer3.register_forward_hook(lambda _, x_in, x_out: print('layer3', x_in, x_out))

        self.layer1.register_backward_hook(lambda _, grad_in, grad_out: print('layer1', grad_in, grad_out))
        self.layer2.register_backward_hook(lambda _, grad_in, grad_out: print('layer2', grad_in, grad_out))
        self.layer3.register_backward_hook(lambda _, grad_in, grad_out: print('layer3', grad_in, grad_out))

        # self.layer2.weight.requires_grad = False

    def forward(self, x):
        out = self.layer1(x)
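        # NOTE: detach() cuts the autograd graph at this point, so nothing
        # before this line (layer1, layer2) can receive gradients.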
        out = self.layer2(out).detach()
        out = self.layer3(out)

        return out


if __name__ == '__main__':
    model = Demo()
    optimizer = optim.SGD(params=model.parameters(), lr=0.1)
    optimizer.zero_grad()

    print(model.layer1.weight, model.layer2.weight, model.layer3.weight)
    # for i in range(10):
    x = torch.ones((1, 1))
    out = model(x)
    out.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(model.layer1.weight, model.layer2.weight, model.layer3.weight)

will print:

Parameter containing:
tensor([[1.]], requires_grad=True) 
Parameter containing:
tensor([[2.]], requires_grad=True) 
Parameter containing:
tensor([[3.]], requires_grad=True)
layer3 (None, tensor([[2.]])) (tensor([[1.]]),)
Parameter containing:
tensor([[1.]], requires_grad=True) 
Parameter containing:
tensor([[2.]], requires_grad=True) 
Parameter containing:
tensor([[2.8000]], requires_grad=True)

The gradient of layer1 is not calculated at all. Therefore, freezing with requires_grad = False (rather than detach()) is the correct solution.

Yeah, I just used that as notation for the functional-style equivalent. In OOP style, the weight is not a function argument, so for an nn.Parameter you usually disable gradients with requires_grad.
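A tiny sketch of that functional-style equivalent (w here is a free-standing weight made up for illustration):

import torch
import torch.nn.functional as F

x = torch.ones(1, 1, requires_grad=True)
w = torch.full((1, 1), 2.0, requires_grad=True)

# f(x, p.detach()): detach only the parameter, not the whole chain.
y = F.linear(x, w.detach())
y.backward()

print(x.grad)  # tensor([[2.]]) -- the gradient still flows to the input
print(w.grad)  # None -- the detached parameter receives no gradient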