How to backward only a subset of neural network parameters? (avoid retain_graph=True)

Hey;

At the beginning of training, I create a neural network NN.

I create the optimizer with:

optimizer = optim.Adam(NN.parameters(), lr=1e-3)

During training, I add new layers to this network (imagine dynamically increasing the number of layers of a residual network).

I call optimizer.add_param_group({"params": new_layer_params}) each time a new layer is created.

However, when I add a new layer to my NN, I want to train only the new layer’s parameters for a few steps; that is, ignore the previous layers’ parameters and optimize only the newest layer for T steps. After this, I will train the full NN (optimize all layers’ parameters).

How should I do this?

My current tentative approach:

(1) Create a list of optimizers, where each optimizer is responsible for the parameters of one layer.

opt = []  # collection of optimizers

optimizer = Adam(NN.parameters())  # my first layer parameters
opt.append(optimizer)

for l in range(total_number_of_layers):
    # add new layers to NN
    ... some code here

    # add a new optimizer to the collection; this optimizer is only responsible for the new layer parameters
    opt.append(Adam(new_layer_parameters))

    # only train new layers
    for t in range(T1):
        opt[-1].zero_grad()
        loss = get_loss(...)
        loss.backward(retain_graph=True)
        opt[-1].step()  # only update new parameters

    # Fully train
    for t in range(T2):
        for o in opt:
            o.zero_grad()
        loss = get_loss(...)
        loss.backward(retain_graph=True)
        for o in opt:
            o.step()

However, I notice that retain_graph=True is extremely memory inefficient. My program uses a very large amount of memory.

I suspect that when I only train the new layers, loss.backward() still computes gradients w.r.t. all parameters (including the old layers). Is it possible, in this snippet, to detach() the old layers and only backpropagate into the new layer parameters?
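
For context, here is a minimal, self-contained sketch of that idea (the layers, sizes, and data below are made up, not the actual NN): freezing the old parameters with requires_grad = False makes backward() skip them, and because every step runs a fresh forward pass, retain_graph=True is not needed.

import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical toy setup: 'old_layer' stands for the existing network, 'new_layer' for the freshly added one.
old_layer = nn.Linear(4, 4)
new_layer = nn.Linear(4, 1)
optimizer = optim.Adam(list(old_layer.parameters()) + list(new_layer.parameters()), lr=1e-3)

# Freeze the old parameters: backward() will not compute gradients for them.
for p in old_layer.parameters():
    p.requires_grad = False

x, y = torch.randn(8, 4), torch.randn(8, 1)

for _ in range(5):  # train only the new layer for a few steps
    optimizer.zero_grad()
    loss = (new_layer(old_layer(x)) - y).pow(2).mean()  # fresh forward pass each step
    loss.backward()   # no retain_graph needed: the graph is rebuilt every iteration
    optimizer.step()  # frozen params have .grad == None, so Adam leaves them untouched

# Unfreeze and continue training everything.
for p in old_layer.parameters():
    p.requires_grad = True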

=========================================

According to the first reply:

Continuing the discussion from How to backward only a subset of neural network parameters? (temporally detach some parameters):

I have changed the code to:

opt = []  # collection of optimizers

optimizer = Adam(NN.parameters())  # my first layer parameters
opt.append(optimizer)

for l in range(total_number_of_layers):
    # add new layers to NN
    ... some code here

    # add a new optimizer to the collection; this optimizer is only responsible for the new layer parameters
    opt.append(Adam(new_layer_parameters))

    # freeze previous parameters
    for name, param in NN.named_parameters():
        if name in previous:  # some code here to check
            param.requires_grad = False

    # only train new layers
    for t in range(T1):
        opt[-1].zero_grad()
        loss = get_loss(...)
        loss.backward(retain_graph=True)
        opt[-1].step()  # only update new parameters

    # unfreeze them
    for param in NN.parameters():
        param.requires_grad = True

    # Fully train
    for t in range(T2):
        for o in opt:
            o.zero_grad()
        loss = get_loss(...)
        loss.backward()
        for o in opt:
            o.step()

However, at the first iteration of l we skip the # only train new layers part (since there are no previous layers yet) and only run the # Fully train part. Then at the second iteration, when we do run the # only train new layers part, it once again raises an error saying that I have to use retain_graph=True.

You should be able to achieve this by setting requires_grad to False for the parameters you don’t want to train.

Thanks for the reply. I have added the requires_grad toggling, but I still run into an error saying I must use retain_graph=True; you can see my updated post above.

I don’t think you need multiple optimisers.

You could try to instantiate the optimiser with the parameters that have requires_grad = True:

optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wdecay)

When you add new layers to your network, you should add the parameters of the new layer via the add_param_group() method.

At this point you will have an optimiser which can optimise the new network (old + new layer).

What’s missing is the ability to freeze the old layers, and I think this can be accomplished just by setting requires_grad=False on the parameters you want to freeze.

Check this as a simple example:

import torch
import torch.optim as optim

w1 = torch.randn(3, 3)
w1.requires_grad = True

o = optim.Adam([w1])
print(o.param_groups)

This prints:

[{'params': [tensor([[-0.7050,  1.0218, -0.7735],
        [-0.0213, -1.2361,  1.2676],
        [-0.8811, -0.5080,  0.4427]], requires_grad=True)], 'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}]

If you then set w1.requires_grad = False, o.param_groups contains:

[{'params': [tensor([[-0.7050,  1.0218, -0.7735],
        [-0.0213, -1.2361,  1.2676],
        [-0.8811, -0.5080,  0.4427]])], 'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}]

Also, computing the loss twice and calling backward() twice seems a bit strange to me, but I may be missing something. Why don’t you just freeze the layers you don’t want, train on all or part of the dataset for one or more epochs, then unfreeze the network and train the whole thing again?

You are forced to retain the graph, with everything that implies (e.g. memory footprint), just because you are computing gradients for the loss twice on the same step.
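
Putting those pieces together, a rough sketch of the single-optimizer pattern described above might look like this (the model, layer names, and data are invented for illustration). Each phase runs its own forward pass before calling backward(), so the graph is rebuilt every step and retain_graph=True is never required.

import torch
import torch.nn as nn
import torch.optim as optim

# Toy "old" network; the sizes and learning rate are arbitrary.
model = nn.Sequential(nn.Linear(2, 8), nn.ReLU())
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

x, y = torch.randn(16, 2), torch.randn(16, 1)

# A new layer gets added dynamically and registered with the same optimizer.
new_layer = nn.Linear(8, 1)
model.add_module("head", new_layer)
optimizer.add_param_group({"params": list(new_layer.parameters())})

# Phase 1: freeze everything except the new layer, then train for a few steps.
for name, p in model.named_parameters():
    p.requires_grad = name.startswith("head")

for _ in range(10):
    optimizer.zero_grad()
    loss = (model(x) - y).pow(2).mean()  # a new graph is built on every iteration
    loss.backward()                      # gradients are computed only for the new layer
    optimizer.step()

# Phase 2: unfreeze and train the whole network.
for p in model.parameters():
    p.requires_grad = True

for _ in range(10):
    optimizer.zero_grad()
    loss = (model(x) - y).pow(2).mean()
    loss.backward()
    optimizer.step()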

Thanks for the help! The optimizer trick is very helpful!

Yes; my plan would be to (1) freeze the layers I don’t want, train for a few epochs, then unfreeze the network, and (2) train the whole thing again.

So (1) and (2) involve two different losses here.

But in this case, did you mean that I must use retain_graph=True? Because so far both losses force me to retain the graph, and this leads to running out of memory.

If you have parameters with requires_grad = False, backward() does not compute a gradient for those.
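
A quick, throwaway check of this behaviour (toy layers, nothing to do with the real model): frozen parameters still have .grad equal to None after backward().

import torch
import torch.nn as nn

layer_a = nn.Linear(3, 3)  # stand-in for an "old", frozen layer
layer_b = nn.Linear(3, 1)  # stand-in for the "new" layer

for p in layer_a.parameters():
    p.requires_grad = False

out = layer_b(layer_a(torch.randn(5, 3))).sum()
out.backward()

print([p.grad for p in layer_a.parameters()])        # [None, None] -> no gradients computed
print([p.grad.shape for p in layer_b.parameters()])  # gradients only for the unfrozen layer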

So you could try something like:

def train():
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)

        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

for epoch in range(1, 6):
    if epoch == 2:
        for name, param in model.named_parameters():
            param.requires_grad = False

        # 1. add new layers to the model
        # 2. register their parameters with the optimizer via add_param_group()
    elif epoch == 4:
        for name, param in model.named_parameters():
            param.requires_grad = True

    train()

So this would train for 1 epoch on the original model. At epoch 2, new layers are added and only the new layers are trained until epoch 4, when the entire model is unfrozen and trained for two more epochs.

Thanks for the detailed solution; however, I’ve just realized that my project is slightly different. I have two kinds of training procedures: (1) training the new layers and (2) training all layers. They are trained on different batches with slightly different procedures: one requires a loop over a dataset, the other doesn’t.

They would look something like this:



# Training the new layers requires a loop over a dataset
for data, target in dataset_1:
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

# Training all layers doesn't loop over the dataset
optimizer.zero_grad()
output = model(dataset2)
loss = criterion(output, target)
loss.backward()
optimizer.step()
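
For what it’s worth, a sketch of how those two phases could run without retain_graph=True, assuming the old layers are frozen during phase (1); the model, criterion, and datasets below are stand-ins, not the real code.

import torch
import torch.nn as nn
import torch.optim as optim

# Stand-in model: layers "0" and "1" play the role of the old network, layer "2" the newly added one.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# Stand-ins for the two datasets in the snippet above.
dataset_1 = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(5)]
dataset2, target = torch.randn(32, 4), torch.randn(32, 1)

# Phase (1): freeze the old layers and loop over dataset_1, training only the new layer.
for name, p in model.named_parameters():
    p.requires_grad = name.startswith("2.")

for data, batch_target in dataset_1:
    optimizer.zero_grad()
    loss = criterion(model(data), batch_target)  # fresh forward pass for every batch
    loss.backward()                              # no retain_graph: the graph is rebuilt each time
    optimizer.step()

# Phase (2): unfreeze everything and take a single full-network step on dataset2.
for p in model.parameters():
    p.requires_grad = True

optimizer.zero_grad()
loss = criterion(model(dataset2), target)
loss.backward()
optimizer.step()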

However, I created a simple toy linear-regression example, and it doesn’t seem to require retain_graph=True, whereas my real code, which is very similar, does require retain_graph=True.
This is quite confusing; I’m not sure I fully understand how backward() works.

import torch
import torch.nn as nn
from tqdm import tqdm

x = torch.randn(100, 1)
y = 1 + 2 * x + torch.randn(x.shape[0]).reshape(100, 1)
criteria = lambda pred: ((pred - y).pow(2)).mean()  # MSE against the fixed targets
n_dim = 1
F = torch.nn.Sequential()

for l in range(0, 10):
    F.add_module("%d" % l, nn.Linear(n_dim, 1))

    if l == 0:  # first layer: create the optimizer
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, F.parameters()), lr=1e-3)
    else:  # new layer: add its parameters to the optimizer
        params = [param for name, param in F.named_parameters() if "%d." % l in name]
        optimizer.add_param_group({"params": params})

    # freeze everything except the newest layer
    for name, param in F.named_parameters():
        if "%d." % l not in name:
            param.requires_grad = False

    # train the new layer only
    for _ in tqdm(range(100)):
        optimizer.zero_grad()
        pred_y = F(x)
        loss = criteria(pred_y)
        loss.backward()
        optimizer.step()

    # unfreeze
    for param in F.parameters():
        param.requires_grad = True

    # train all layers
    for _ in tqdm(range(100)):
        optimizer.zero_grad()
        pred_y = F(x)
        loss = criteria(pred_y)
        loss.backward()
        optimizer.step()

Oh, I have finally figured out why it works now! The way I compute the loss somehow involves some globally shared parameters, and that is what requires retain_graph=True! Thanks so much for all the tips.
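
For anyone hitting the same thing, here is a contrived illustration of that kind of pitfall (not the original code): a tensor computed from parameters once, outside the training loop, and then reused in every loss forces retain_graph=True, because the first backward() frees the shared part of the graph. Recomputing it inside the loop (or detaching it, if it should not receive gradients) removes the need for retain_graph.

import torch

w = torch.randn(3, requires_grad=True)

# Problematic pattern: 'shared' is built from w once, OUTSIDE the loop,
# so every loss inside the loop reuses that same piece of graph.
shared = (w * w).sum()
try:
    for _ in range(2):
        loss = shared + (w ** 2).sum()
        loss.backward()  # 2nd call fails: the shared part of the graph was already freed
        w.grad = None
except RuntimeError as err:
    print("would need retain_graph=True:", err)

# Fix: recompute the shared term inside the loop (or .detach() it if it
# shouldn't receive gradients), so each backward() sees a freshly built graph.
for _ in range(2):
    shared = (w * w).sum()
    loss = shared + (w ** 2).sum()
    loss.backward()  # no retain_graph needed
    w.grad = None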

Hi, how do you find the globally shared parameters that require retain_graph=True? I ran into the same problem. Thanks!