Is this the proper way to update one optimizer for 3 models?

Hello all.

I have three nn.Module instances combined in an nn.ModuleList:

decoder = nn.ModuleList([create_model() for _ in range(3)])

I have one optimizer holding the trainable parameters of all 3 decoders (remember, decoder is a list of 3 different models):

optimizer = Adam(decoder.parameters(), lr=1e-3, weight_decay=1e-5)

Each decoder (decoder[0], decoder[1] and decoder[2]) receives a different input, and I need to update their weights separately.

This is a brief version of my loop:

for epoch in range(100):
    for features in dataloader: # features is a list of 3 inputs.
        for idx, input in enumerate(features): 
            
            # Forward pass
            loss = decoder[idx](input) # features[0] into decoder[0], features[1] into decoder[1] and features[2] into decoder[2]

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
           
            # Adam step
            optimizer.step()

Am I correctly updating the weights of each model with its corresponding error?

If it helps, here is a more complete version of my training loop:

for e in range(100):
    for inputs in dataloader:
        with torch.no_grad():
            features = feature_extractor_resnet50(inputs) # output of layers [1, 2, 3]
    
        for idx, input in enumerate(features):

            # Forward pass
            loss = decoder[idx](input)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Adam step
            optimizer.step()
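One way to check what this loop actually does is to watch whether one decoder's weights move during another decoder's optimizer.step(). Because the single Adam optimizer holds the parameters of all three decoders, the answer depends on what zero_grad() leaves behind. This is a minimal experiment, not the thread's code: two nn.Linear layers are hypothetical stand-ins for the decoders, the squared-output loss is hypothetical, and it assumes optimizer.zero_grad() accepts set_to_none (its default became True in PyTorch 2.0):

```python
import torch
from torch import nn
from torch.optim import Adam


def decoder0_moves_during_decoder1_step(set_to_none: bool) -> bool:
    """Return True if decoder[0]'s weights change when only decoder[1] has a loss."""
    torch.manual_seed(0)
    # Hypothetical stand-ins for the decoders (any nn.Module behaves the same way).
    decoder = nn.ModuleList([nn.Linear(4, 1) for _ in range(2)])
    optimizer = Adam(decoder.parameters(), lr=1e-3)
    x = torch.randn(8, 4)

    # Step 1: a real update of decoder[0] only (this fills its Adam moment estimates).
    optimizer.zero_grad(set_to_none=set_to_none)
    decoder[0](x).pow(2).mean().backward()
    optimizer.step()

    before = decoder[0].weight.detach().clone()

    # Step 2: a loss that depends only on decoder[1].
    optimizer.zero_grad(set_to_none=set_to_none)
    decoder[1](x).pow(2).mean().backward()
    optimizer.step()

    return not torch.equal(before, decoder[0].weight)


print(decoder0_moves_during_decoder1_step(set_to_none=True))   # False
print(decoder0_moves_during_decoder1_step(set_to_none=False))  # True
```

With set_to_none=True, Adam skips parameters whose .grad is None, so each step only updates the decoder that produced the loss. With zeroed (not None) gradients, Adam still moves the other decoders using its running moment estimates (and would also apply weight_decay to them), so each decoder receives three Adam steps per batch, only one of them with a real gradient.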

I think you should call optimizer.zero_grad() and optimizer.step() as below, so that the parameters are updated together after the 3 forward passes through the 3 decoders.

for e in range(100):
    for inputs in dataloader:
        with torch.no_grad():
            features = feature_extractor_resnet50(inputs) # output of layers [1, 2, 3]

        # zero gradients
        optimizer.zero_grad()

        for idx, input in enumerate(features):

            # Forward pass
            loss = decoder[idx](input)

            # Backward pass
            loss.backward()
    
        # Adam step  (only once for every input)
        optimizer.step()
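The suggested loop relies on backward() accumulating into .grad. Since the three decoders share no parameters and the ResNet features are computed under torch.no_grad(), each decoder's .grad should end up holding exactly the gradient of its own loss. A minimal check of that claim, using hypothetical nn.Linear decoders and a hypothetical squared-output loss:

```python
import torch
from torch import nn


def grads_match_own_loss() -> bool:
    torch.manual_seed(0)
    # Hypothetical stand-ins for the three decoders and their three inputs.
    decoder = nn.ModuleList([nn.Linear(4, 1) for _ in range(3)])
    features = [torch.randn(8, 4) for _ in range(3)]

    # Reference: gradient of each decoder's own loss, computed in isolation.
    expected = []
    for dec, feat in zip(decoder, features):
        loss = dec(feat).pow(2).mean()
        expected.append(torch.autograd.grad(loss, list(dec.parameters())))

    # Suggested loop body: zero once, then one backward() per decoder.
    decoder.zero_grad()
    for dec, feat in zip(decoder, features):
        dec(feat).pow(2).mean().backward()

    # backward() accumulates into .grad, but each decoder's parameters only
    # receive gradient from their own loss, so nothing gets mixed.
    return all(
        torch.allclose(p.grad, g)
        for dec, grads in zip(decoder, expected)
        for p, g in zip(dec.parameters(), grads)
    )


print(grads_match_own_loss())  # True
```

If this holds, the single optimizer.step() afterwards applies each decoder's own gradient, just as the per-decoder loop intends.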

What you said makes sense. I’m still testing it, but I can already say that training time dropped from 40 to 30 minutes, which is great news.

Performance seems to stay about the same with your approach. I think it’s OK.

Sorry, but it doesn’t work as well as I thought. After 3 tests using the same seed, your proposal seems to be a bit behind mine: 91% performance (yours) vs 93% (mine). Why could this be? Your implementation makes sense to me… maybe this is normal, but it is very strange after 3 tests.

Here are the complete details of the training loop:

  1. First, I extract the outputs of resnet50.layer1, resnet50.layer2 and resnet50.layer3 for my input data.
  2. I have 3 decoders, one for each ResNet output: decoder1 uses resnet50.layer1, decoder2 uses resnet50.layer2 and decoder3 uses resnet50.layer3.
  3. Then, the weights of each decoder need to be updated.
model.decoder.train()
train_loss = list()
for inputs in dataloader:
    inputs = inputs.to(device)
    with torch.no_grad():
        features = model.feature_extractor_resnet50(inputs) # outputs of layers [1, 2, 3] of the ResNet

    for idx, feat in enumerate(features): # features is a list of 3 items
        # Forward
        loss = model.decoder[idx](feat) # decoder[0] uses the layer1 output, decoder[1] uses the layer2 output, and so on

        # Backward
        optimizer.zero_grad()
        loss.backward()

        # Update weights (gradients of this decoder only are clipped)
        torch.nn.utils.clip_grad_norm_(model.decoder[idx].parameters(), 1e0)
        optimizer.step()

        train_loss.append(loss.item())

Thanks for the help 🙂
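Step 1 does not say how model.feature_extractor_resnet50 collects the three outputs. One common approach is to register forward hooks on layer1, layer2 and layer3. This sketch assumes a backbone that exposes those attributes (torchvision's ResNet does); TinyBackbone and extract_features are hypothetical names used only for illustration:

```python
import torch
from torch import nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in with layer1/layer2/layer3 attributes, like torchvision's ResNet."""

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Conv2d(3, 8, 3, stride=2, padding=1)
        self.layer2 = nn.Conv2d(8, 16, 3, stride=2, padding=1)
        self.layer3 = nn.Conv2d(16, 32, 3, stride=2, padding=1)

    def forward(self, x):
        return self.layer3(self.layer2(self.layer1(x)))


def extract_features(backbone: nn.Module, inputs: torch.Tensor) -> list:
    """Return the outputs of layer1, layer2 and layer3 (in that order) via forward hooks."""
    outputs = []
    handles = [
        getattr(backbone, name).register_forward_hook(
            lambda module, args, out: outputs.append(out)
        )
        for name in ("layer1", "layer2", "layer3")
    ]
    try:
        with torch.no_grad():  # the backbone is frozen; only the decoders are trained
            backbone(inputs)
    finally:
        for handle in handles:
            handle.remove()  # avoid piling up hooks across calls
    return outputs


features = extract_features(TinyBackbone().eval(), torch.randn(2, 3, 32, 32))
print([tuple(f.shape) for f in features])  # [(2, 8, 16, 16), (2, 16, 8, 8), (2, 32, 4, 4)]
```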

I couldn’t think of a concrete reason for this behavior.
I will think more and come back if I find some valid explanations.

Yes, I will test it again tomorrow. Maybe I made a mistake. I’ll let you know. Thanks

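You might want to try simply summing up the losses, zeroing the gradients, and calling backward() once on the sum, instead of looping over the three modules. This builds essentially the same computation graph: each decoder’s loss depends only on its own parameters, so the gradient of the sum with respect to a decoder’s parameters is just the gradient of that decoder’s loss.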

I assume you are training a multitask problem; if so, it’s worth trying to multiply each loss by a weight.
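A minimal sketch of the weighted sum, assuming hypothetical toy decoders, a hypothetical squared-output loss, and illustrative weights (not values from the thread). It checks that decoder i's gradient is w_i times the gradient of its own loss:

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical stand-ins for the three decoders, their inputs, and per-task weights.
decoder = nn.ModuleList([nn.Linear(4, 1) for _ in range(3)])
features = [torch.randn(8, 4) for _ in range(3)]
loss_weights = [1.0, 0.5, 2.0]  # illustrative values, to be tuned for the task

# Unweighted per-decoder losses (stand-in for each decoder returning its own loss).
losses = [dec(feat).pow(2).mean() for dec, feat in zip(decoder, features)]

# One backward() on the weighted sum instead of three separate backward() calls.
decoder.zero_grad()
total_loss = sum(w * l for w, l in zip(loss_weights, losses))
total_loss.backward()

# d(sum_j w_j * L_j) / d(theta_i) = w_i * dL_i / d(theta_i)
for w, dec, feat in zip(loss_weights, decoder, features):
    own = torch.autograd.grad(dec(feat).pow(2).mean(), dec.weight)[0]
    assert torch.allclose(dec.weight.grad, w * own)
```

One caveat for this setup: since the decoders share no parameters, a constant weight w_i only rescales decoder i's gradients, and Adam's update is largely insensitive to a constant gradient scale (apart from eps and weight_decay). With Adam, the weights would mostly matter through a clip_grad_norm_ computed jointly over all decoders.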

Hello. I will try summing the losses; it could work.
Weighting the losses is a good idea too.

Hello all.

Well, after a long time I would like to reopen this thread, because after extensive testing there is something I cannot understand.

To recap what I am doing:

  • I have to train three decoders
  • I have one optimizer for the three decoders
  • One iteration (one batch) is complete once all three decoders have been used

Here is the code that creates the decoders and the optimizer:

decoder_list = nn.ModuleList([decoder1(), decoder2(), decoder3()])
..
optimizer = Adam(decoder_list.parameters(), lr=1e-3)

Here is a simplified version of my training loop:

for epoch in range(100):
    for features in dataloader: #  list of 3 inputs, one for each decoder.
        for idx, input in enumerate(features):
            # Forward pass
            loss = decoder_list[idx](input)
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(decoder_list[idx].parameters(), 1.)
            # Adam step
            optimizer.step()

Following @InnovArul’s suggestion, I tried this version of the training loop, but the performance is a bit worse:

for epoch in range(100):
    for features in dataloader: #  List of three elements, one per decoder.
        # Zero gradients once per batch
        optimizer.zero_grad()
        for idx, input in enumerate(features):
            # Forward pass per each decoder
            loss = decoder_list[idx](input)
            # Backward pass per each decoder (gradients accumulate in .grad)
            loss.backward()

        # After all three decoders, clip the gradients of the 3 decoders together
        torch.nn.utils.clip_grad_norm_(decoder_list.parameters(), 1.)
        # Adam step
        optimizer.step()
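One difference from the first loop is where the norm is computed. There, clip_grad_norm_ receives decoder_list[id].parameters(), so each decoder is clipped against its own gradient norm. Here it receives decoder_list.parameters(), so one total norm is computed over all three decoders and the same scaling factor is applied to all of them: a decoder with large gradients shrinks the updates of the others. A minimal sketch with hypothetical toy decoders whose losses are scaled (0.1, 1 and 100, chosen only for illustration) so that their gradient sizes differ:

```python
import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_


def grad_norm(module: nn.Module) -> float:
    """L2 norm over all gradients of one module."""
    return torch.norm(torch.stack([p.grad.norm() for p in module.parameters()])).item()


def clipped_norms(per_decoder: bool) -> list:
    """Gradient norm of each toy decoder after clipping."""
    torch.manual_seed(0)
    # Hypothetical decoders and inputs; the loss scales make gradient sizes differ.
    decoder_list = nn.ModuleList([nn.Linear(4, 1) for _ in range(3)])
    features = [torch.randn(8, 4) for _ in range(3)]
    scales = [0.1, 1.0, 100.0]

    decoder_list.zero_grad()
    for dec, x, s in zip(decoder_list, features, scales):
        (s * dec(x).pow(2).mean()).backward()

    if per_decoder:
        # As in the first loop: each decoder is clipped against its own norm.
        for dec in decoder_list:
            clip_grad_norm_(dec.parameters(), 1.0)
    else:
        # As in the later loops: one total norm over all three decoders.
        clip_grad_norm_(decoder_list.parameters(), 1.0)
    return [grad_norm(dec) for dec in decoder_list]


print(clipped_norms(per_decoder=True))   # every entry is at most 1.0
print(clipped_norms(per_decoder=False))  # entries share one scaling factor
```

In the joint case, the small-gradient decoder ends up with a much smaller gradient than when it is clipped on its own.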

Then I followed @ksmdanl’s suggestion of accumulating the losses, but as before, the performance is a bit worse:

for epoch in range(100):
    for features in dataloader: #  List of three elements, one per decoder.
        # Zero gradients once per batch
        optimizer.zero_grad()
        loss = 0.0 # accumulates the losses of the three decoders
        for idx, input in enumerate(features):
            # Forward pass per each decoder
            loss += decoder_list[idx](input)

        # Single backward pass after all three decoders
        loss.backward()
        torch.nn.utils.clip_grad_norm_(decoder_list.parameters(), 1.)
        # Adam step
        optimizer.step()

Maybe my mistake is in clip_grad_norm_? Why does only the first method work? To me, the other ones make sense too.
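One way to narrow this down is to check whether the joint version reproduces the first method once clipping is done per decoder. A minimal sketch, assuming PyTorch 2.0 or later (where optimizer.zero_grad() sets gradients to None, so Adam skips decoders without a gradient), hypothetical nn.Linear decoders, random data, and a hypothetical squared-output loss scaled by 10 so that clipping is likely to trigger:

```python
import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam


def train(joint: bool, steps: int = 5) -> list:
    torch.manual_seed(0)
    # Hypothetical stand-ins for decoder1(), decoder2(), decoder3().
    decoder_list = nn.ModuleList([nn.Linear(4, 1) for _ in range(3)])
    optimizer = Adam(decoder_list.parameters(), lr=1e-3)
    # Hypothetical data: `steps` batches, each a list of 3 inputs (one per decoder).
    batches = [[torch.randn(8, 4) for _ in range(3)] for _ in range(steps)]

    for features in batches:
        if joint:
            # One zero_grad/backward/step per batch, but each decoder clipped separately.
            optimizer.zero_grad()
            loss = sum(10.0 * dec(x).pow(2).mean() for dec, x in zip(decoder_list, features))
            loss.backward()
            for dec in decoder_list:
                clip_grad_norm_(dec.parameters(), 1.0)
            optimizer.step()
        else:
            # The first method: one zero_grad/backward/clip/step per decoder.
            for dec, x in zip(decoder_list, features):
                optimizer.zero_grad()  # set_to_none=True by default in PyTorch >= 2.0
                (10.0 * dec(x).pow(2).mean()).backward()
                clip_grad_norm_(dec.parameters(), 1.0)
                optimizer.step()
    return [p.detach().clone() for p in decoder_list.parameters()]


same = all(torch.allclose(a, b) for a, b in zip(train(joint=False), train(joint=True)))
print(same)  # expected True on PyTorch >= 2.0
```

If the two versions match, the remaining differences between the methods would be the shared clipping factor and, on PyTorch versions before 2.0 (where zero_grad() zeroed gradients by default instead of setting them to None), the extra Adam steps that the first loop gives to the decoders that did not produce the current loss.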

Thanks