Multiple loss function optimization

Hi team, I have a loss function that computes the loss over multiple domains and sums the per-domain losses to form the final loss, and in each epoch I run gradient descent on it. Could you please let me know if the implementation is correct?

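import torch
from torch.utils.data import DataLoader

# Dataset, customLoss, multitaskAE, optimizer, epochs, domains, batch_size,
# device and train_losses are assumed to be defined elsewhere.

def train_AE():
    loss_fn = customLoss()  # create the criterion once instead of once per batch

    for epoch in range(epochs):
        train_loss = 0.0
        num_samples = 0
        losses = []  # one loss tensor per batch, collected across all domains

        for domain in range(domains):
            dataset = Dataset(dom=domain)
            dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
            num_samples += len(dataset)

            for inputs, out, targets in dataloader:
                inputs = inputs.float().to(device)
                dom_out = out.float().to(device)  # ground truth data
                targets = targets.long().to(device)
                logits, recon_batch = multitaskAE(inputs)  # model
                loss = loss_fn(recon_batch, dom_out, logits, targets)
                losses.append(loss)  # keeps this batch's graph alive until backward()

        # Sum the losses of all domains and take one gradient step per epoch
        final_loss = torch.sum(torch.stack(losses))
        optimizer.zero_grad()
        final_loss.backward()
        optimizer.step()
        train_loss += final_loss.item()

        if epoch % 50 == 0:
            # average over the samples of all domains, not just the last one
            print('====> Epoch: {} Average loss: {:.4f}'.format(
                epoch, train_loss / num_samples))
            train_losses.append(train_loss / num_samples)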

@ptrblck, could you please take a look at the code above? I am trying to compute a loss over multiple domains, where the final loss is the sum of all of them. Is this the correct way to implement it, or is there a better way? Thank you!

Your code looks correct, but it accumulates the computation graphs of the entire dataset, which increases memory usage. The final_loss.backward() call then frees all stored intermediate tensors and releases that memory. If you run out of memory, you could call backward() on each loss inside the DataLoader loop (the gradients accumulate in .grad) and scale the losses accordingly. While this reduces memory usage, the computation would be more expensive, since backward() is called multiple times.
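A minimal sketch of the two options described above, using a toy `nn.Linear` model and random batches as stand-ins for `multitaskAE` and the per-domain DataLoaders (these are assumptions, not the poster's setup). The first function keeps every batch's graph alive and calls `backward()` once on the stacked loss; the second calls `backward()` per batch and lets the gradients accumulate in `.grad`. Because this sketch averages the stacked losses, each per-batch loss is divided by the total number of batches; with `torch.sum`, as in the original code, no scaling would be needed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins (assumptions): a linear model and a few batches per domain.
model = nn.Linear(5, 1)
criterion = nn.MSELoss()
domains = [
    [(torch.randn(8, 5), torch.randn(8, 1)) for _ in range(3)],  # domain 0
    [(torch.randn(8, 5), torch.randn(8, 1)) for _ in range(2)],  # domain 1
]
all_batches = [b for dom in domains for b in dom]
num_batches = len(all_batches)


def grads_from_stacked_loss():
    """Keep every batch's graph alive, then call backward once."""
    model.zero_grad()
    losses = [criterion(model(x), y) for x, y in all_batches]
    final_loss = torch.stack(losses).mean()
    final_loss.backward()  # frees all stored graphs at once
    return [p.grad.clone() for p in model.parameters()]


def grads_from_per_batch_backward():
    """Call backward per batch; gradients accumulate in .grad."""
    model.zero_grad()
    for x, y in all_batches:
        loss = criterion(model(x), y) / num_batches  # scale to match .mean()
        loss.backward()  # this batch's graph is freed immediately
    return [p.grad.clone() for p in model.parameters()]


g_stacked = grads_from_stacked_loss()
g_per_batch = grads_from_per_batch_backward()
print(all(torch.allclose(a, b, atol=1e-6) for a, b in zip(g_stacked, g_per_batch)))
# → True
```

Both variants produce the same gradients up to floating-point rounding; the per-batch version only holds one batch's graph in memory at a time, at the cost of more backward() calls.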

Thank you so much for clarifying this! I found that using torch.mean instead of torch.sum brings the loss values into a different range, but the model performance does not degrade. So torch.mean can be used as well, right?

Yes, taking the mean of the stacked losses would also work, as it just scales the loss, and thus the gradients, by 1/N, where N is the number of stacked losses.
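A small sketch illustrating that swapping `torch.sum` for `torch.mean` only rescales the gradients. The single parameter `w`, the four random inputs and the squared per-domain loss are hypothetical stand-ins, not the poster's model.

```python
import torch

torch.manual_seed(0)

# Hypothetical setup: one shared parameter and N per-domain losses.
w = torch.randn(3, requires_grad=True)
inputs = [torch.randn(3) for _ in range(4)]  # N = 4 "domains"
N = len(inputs)


def domain_losses():
    # Each domain contributes one scalar loss that depends on w
    return torch.stack([(w * x).sum() ** 2 for x in inputs])


# Gradient of the summed loss
w.grad = None
domain_losses().sum().backward()
grad_sum = w.grad.clone()

# Gradient of the averaged loss
w.grad = None
domain_losses().mean().backward()
grad_mean = w.grad.clone()

# The mean's gradient equals the sum's gradient divided by N
print(torch.allclose(grad_mean, grad_sum / N))  # → True
```

The gradient direction is unchanged; only its magnitude shrinks by 1/N. With plain SGD this behaves like dividing the learning rate by N, while adaptive optimizers such as Adam largely normalize a constant scale like this away, which may be why no performance drop was observed.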