Is this a proper way to update an optimizer for 3 models?

Hello all.

I have three nn.Module instances grouped in an nn.ModuleList:

decoder = nn.ModuleList([create_model() for _ in range(3)])

I have a single optimizer holding the trainable parameters of all 3 decoders (remember, decoder is a list of 3 different models):

optimizer = Adam(decoder.parameters(), lr=1e-3, weight_decay=1e-5)

Each decoder (decoder[0], decoder[1] and decoder[2]) takes a different input, and I need to update their weights separately.

This is a brief version of what my loop looks like.

for epoch in range(100):
    for features in dataloader: # features is a list of 3 inputs.
        for idx, input in enumerate(features): 
            
            # Forward pass
            loss = decoder[idx](input)  # features[0] goes into decoder[0], features[1] into decoder[1] and features[2] into decoder[2]

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
           
            # Adam step
            optimizer.step()

Am I correctly updating the weights of each model with its corresponding error?

If you prefer, here is a more complete version of my training loop:

for e in range(100):
    for inputs in dataloader:
        with torch.no_grad():
            features = feature_extractor_resnet50(inputs) # output of layers [1, 2, 3]
    
        for idx, input in enumerate(features):

            # Forward pass
            loss = decoder[idx](input)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Adam step
            optimizer.step()

I think you should place optimizer.zero_grad() and optimizer.step() as below, so that the parameters are updated together after the 3 forward passes through the 3 decoders.

for e in range(100):
    for inputs in dataloader:
        with torch.no_grad():
            features = feature_extractor_resnet50(inputs) # output of layers [1, 2, 3]

        # zero gradients
        optimizer.zero_grad()

        for idx, input in enumerate(features):

            # Forward pass
            loss = decoder[idx](input)

            # Backward pass
            loss.backward()
    
        # Adam step  (only once for every input)
        optimizer.step()
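
For anyone who wants to try this pattern in isolation, here is a minimal sketch of the single-step loop on toy data. `ToyDecoder`, the feature sizes and the random inputs are hypothetical stand-ins (the real `create_model()` and the ResNet features are not shown in this thread); the only thing it is meant to illustrate is that one `optimizer.step()` after three `backward()` calls updates all three decoders.

```python
import torch
from torch import nn
from torch.optim import Adam


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for create_model(): forward returns a scalar loss."""

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Linear(dim, dim)

    def forward(self, feat):
        # Reconstruction-style loss, chosen only for illustration
        return (self.net(feat) - feat).pow(2).mean()


torch.manual_seed(0)
dims = [8, 16, 32]  # assumed feature sizes, one per decoder
decoder = nn.ModuleList([ToyDecoder(d) for d in dims])
optimizer = Adam(decoder.parameters(), lr=1e-3, weight_decay=1e-5)

# Random tensors standing in for the three feature maps of one batch
features = [torch.randn(4, d) for d in dims]
before = [d.net.weight.detach().clone() for d in decoder]

optimizer.zero_grad()                 # clear gradients once per batch
for idx, feat in enumerate(features):
    loss = decoder[idx](feat)         # each decoder sees only its own feature
    loss.backward()                   # gradients land only on decoder[idx]
optimizer.step()                      # one Adam step covering all three decoders

# True for every decoder whose weights were changed by the step
changed = [not torch.equal(b, d.net.weight.detach()) for b, d in zip(before, decoder)]
```

Because the three decoders share no parameters, each `backward()` call only fills in the gradients of its own decoder, so nothing is mixed between them.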

What you said makes sense. I'm still testing it, but I can say that training time dropped from 40 to 30 minutes, which is great news.

It seems that performance stays similar with your approach. I think it's OK.


Sorry, but it doesn't work as well as I thought. After 3 tests using the same seed, your proposal seems to be a bit behind mine: 91% performance (yours) vs 93% performance (mine). Why could this be? Your implementation makes sense to me. Maybe this is normal, but it seems strange to see it in all 3 tests.

Here are the complete details of the training loop.

  1. First, I extract the outputs of resnet50.layer1, resnet50.layer2 and resnet50.layer3 for some data.
  2. I have 3 decoders, one for each output of the resnet. That is, decoder1 uses resnet50.layer1, decoder2 uses resnet50.layer2 and decoder3 uses resnet50.layer3.
  3. Then the weights of each decoder need to be updated.

model.decoder.train()
train_loss = []
for inputs in dataloader:
    inputs = inputs.to(device)
    with torch.no_grad():
        features = model.feature_extractor_resnet50(inputs)  # outputs of layers [1, 2, 3] of the resnet

    for idx, feat in enumerate(features):  # features is a list of 3 items
        # Forward
        loss = model.decoder[idx](feat)  # decoder[0] uses features[0], decoder[1] uses features[1], and so on

        # Backward
        optimizer.zero_grad()
        loss.backward()

        # Clip the gradient norm of this decoder only, then update weights
        torch.nn.utils.clip_grad_norm_(model.decoder[idx].parameters(), 1e0)
        optimizer.step()

        train_loss.append(loss.item())

Thanks for the help!

I couldn’t think of a concrete reason for this behavior.
I will think more and come back if I find some valid explanations.
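
One thing that might be worth checking, offered as a possibility and not a confirmed cause: how `optimizer.zero_grad()` clears gradients in the installed PyTorch version. With `set_to_none=True` (the default in PyTorch 2.0 and later) the gradients of the decoders not used in the current inner iteration are `None`, and Adam skips them. With `set_to_none=False` they are zero tensors, and Adam still steps those parameters because of `weight_decay` and its running averages, so the per-decoder loop would apply three Adam updates per batch to every decoder. The sketch below uses two hypothetical `nn.Linear` decoders and a made-up squared-output loss to make the difference visible.

```python
import torch
from torch import nn
from torch.optim import Adam


def other_decoder_moves(set_to_none):
    """Return True if decoder[1] changes when only decoder[0]'s loss is backpropagated."""
    torch.manual_seed(0)
    # Two tiny hypothetical decoders; the loss is a toy squared output
    decoder = nn.ModuleList([nn.Linear(4, 1) for _ in range(2)])
    optimizer = Adam(decoder.parameters(), lr=1e-3, weight_decay=1e-5)
    x = torch.randn(8, 4)

    # Populate .grad on every parameter once, as an earlier iteration would have
    sum(d(x).pow(2).mean() for d in decoder).backward()

    optimizer.zero_grad(set_to_none=set_to_none)
    before = decoder[1].weight.detach().clone()

    # Only decoder[0] takes part in this forward/backward pass
    decoder[0](x).pow(2).mean().backward()
    optimizer.step()

    return not torch.equal(before, decoder[1].weight.detach())


moved_with_zeros = other_decoder_moves(set_to_none=False)  # grads are zero tensors
moved_with_none = other_decoder_moves(set_to_none=True)    # grads are None, Adam skips them
```

Gradient clipping is a second place where the two loops can diverge: clipping `model.decoder[idx].parameters()` inside the inner loop bounds each decoder's gradient norm separately, which is not the same as clipping all decoder parameters together before a single step.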


Yes, tomorrow I will test it again. Maybe I made a mistake. I'll let you know. Thanks.


You might want to try simply zeroing the gradients, summing the three losses, and calling backward once on the sum, as opposed to calling backward separately for each of the three modules. This approach would build more or less the same computation graph.

I assume you are training a multitask problem, so it's worth trying to multiply each loss by its own weight.
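
A minimal sketch of that idea, under some assumptions: the decoders are hypothetical `nn.Linear` layers, `toy_loss` is a made-up loss, and the values in `loss_weights` are arbitrary placeholders, not recommendations. It also checks that one `backward()` on the weighted sum produces the same gradients as three separate `backward()` calls, which is expected here because the decoders share no parameters.

```python
import torch
from torch import nn

torch.manual_seed(0)
dims = [8, 16, 32]                      # assumed feature sizes
decoder = nn.ModuleList([nn.Linear(d, 1) for d in dims])
features = [torch.randn(4, d) for d in dims]
loss_weights = [1.0, 0.5, 0.25]         # hypothetical per-decoder weights


def toy_loss(module, feat):
    # Stand-in loss; in the thread the decoder itself returns the loss
    return module(feat).pow(2).mean()


# Variant A: weight the losses, sum them, call backward once
total = sum(w * toy_loss(decoder[i], f)
            for i, (w, f) in enumerate(zip(loss_weights, features)))
total.backward()
summed_grads = [p.grad.clone() for p in decoder.parameters()]

# Variant B: reset gradients, then call backward on each weighted loss
for p in decoder.parameters():
    p.grad = None
for i, (w, f) in enumerate(zip(loss_weights, features)):
    (w * toy_loss(decoder[i], f)).backward()
separate_grads = [p.grad.clone() for p in decoder.parameters()]

# Compare the gradients produced by the two variants
same = all(torch.allclose(a, b) for a, b in zip(summed_grads, separate_grads))
```

In a real training loop either variant would be followed by a single `optimizer.step()`; the weights only rescale how strongly each decoder's loss contributes to its own gradients.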


Hello. I will try summing the losses. It could work.
Weighting the losses is a good suggestion too.