Multiple losses are causing gradients to "leak"

Hi all,

I am training a model composed of a convolutional AE (which is supposed to learn to compress and decompress the input images) and TIMM’s resnet50_FCN, which then takes the AE output and segments it.

In my training setup I have two different cases: either the AE is trained using only the reconstruction error, or it is trained using a weighted sum of the reconstruction error and the task error. For this I have two optimizers:

    task_optimizer = torch.optim.Adam(model.task_net.parameters(), lr=args.lr)
    AE_optimizer = torch.optim.Adam(list(model.encoder.parameters()) + list(model.decoder.parameters()), lr=args.lr)
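
For reference, the loss weights are set roughly like this (the flag name below is just a placeholder; in my actual script the values come from args):

    # Hypothetical flag for illustration only; the real values come from my config
    if args.task_aware_AE:
        r_loss_w, t_loss_w = 1.0, 1.0  # AE trained on reconstruction + task loss
    else:
        r_loss_w, t_loss_w = 1.0, 0.0  # AE trained on reconstruction loss only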

And this is my training loop:

    # Get images and labels and send them to the device
    images, labels = batch[0].to(device, non_blocking=True), batch[1].to(device, non_blocking=True)
    labels = labels.squeeze()

    # Zero the gradients of both optimizers
    task_optimizer.zero_grad()
    AE_optimizer.zero_grad()

    # Forward pass
    z = model.encoder(images)
    x_hat = model.decoder(z)
    out = model.task_net(x_hat)

    if dataset_id in [3, 4, 5, 6]:  # the segmentation task network outputs a dict
        out = out["out"]

    # Calculate losses, backprop and update the weights
    t_error = task_loss(out, labels)      # task loss
    r_error = recon_loss(x_hat, images)   # reconstruction loss
    AE_error = r_loss_w*r_error + t_loss_w*t_error  # AE trained with a combination of losses

    t_error.backward(retain_graph=True)  # don't delete the graph, it is needed for the next backward pass
    AE_error.backward(retain_graph=False)

    task_optimizer.step()
    AE_optimizer.step()

As you can see in the snippet above, when only the reconstruction error is used, t_loss_w is set to 0. However, when I checked the results, both cases gave quasi-identical performance. After I tried a bunch of things, my supervisor looked at the code and said that, even though I am multiplying the task error by 0 (which makes the combined loss equal to just the reconstruction loss), the AE is still being updated using gradients from the task net (or something like that).
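
To make the problem concrete, here is a tiny standalone toy version of the first loop (toy layers, not my actual model) that seems to show what my supervisor was pointing at: even with t_loss_w = 0, calling backward() on the task loss alone already writes gradients into the encoder's .grad, and the AE optimizer's step() then consumes them together with the reconstruction gradients.

    import torch
    import torch.nn as nn

    # Toy stand-ins for the encoder, decoder and task net
    encoder = nn.Linear(4, 2)
    decoder = nn.Linear(2, 4)
    task_net = nn.Linear(4, 3)

    x = torch.randn(8, 4)
    target = torch.randint(0, 3, (8,))

    x_hat = decoder(encoder(x))
    out = task_net(x_hat)

    t_error = nn.functional.cross_entropy(out, target)
    AE_error = 1.0 * nn.functional.mse_loss(x_hat, x) + 0.0 * t_error  # t_loss_w = 0

    t_error.backward(retain_graph=True)
    print(encoder.weight.grad.abs().sum())  # non-zero: the task loss already put gradients on the encoder

    AE_error.backward()
    # an AE optimizer step here would use task-loss + reconstruction gradients, not reconstruction alone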

We then changed it to this super inefficient code which is supposed to work:

    # Get images and labels and send them to the device
    images, labels = batch[0].to(device, non_blocking=True), batch[1].to(device, non_blocking=True)
    labels = labels.squeeze()

    # Forward pass
    z = model.encoder(images)
    x_hat = model.decoder(z)
    out = model.task_net(x_hat)

    if dataset_id in [3, 4, 5, 6]:  # the segmentation task network outputs a dict
        out = out["out"]

    # Do the updates sequentially so that the non task-aware AE gets updated without residual gradients from the task loss
    # Update TASK NET
    # Calculate loss
    t_error = task_loss(out, labels)  # task loss
    # Clear gradients
    task_optimizer.zero_grad()
    # Backward
    t_error.backward()
    # Update parameters
    task_optimizer.step()

    # Forward pass again (this is super inefficient, but it does not work otherwise?)
    z = model.encoder(images)
    x_hat = model.decoder(z)
    out = model.task_net(x_hat)

    if dataset_id in [3, 4, 5, 6]:  # the segmentation task network outputs a dict
        out = out["out"]

    # Update AE
    # Calculate loss
    r_error = recon_loss(x_hat, images)  # reconstruction loss
    AE_error = t_loss_w*task_loss(out, labels)
    AE_error += r_loss_w*r_error  # AE trained with a combination of losses
    # Clear gradients
    AE_optimizer.zero_grad()
    # Backward
    AE_error.backward()
    # Update parameters
    AE_optimizer.step()
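
As a side note, I was wondering whether something like the sketch below would avoid the second forward pass. It relies on the inputs= argument of backward() (available in recent PyTorch versions, I think 1.8+) to restrict which parameters each loss accumulates gradients into. Note that it is not exactly equivalent to the loop above, because the AE gradients here flow through the task net before its update step rather than after it.

    # Single forward pass (sketch, not tested; same variable names as above)
    z = model.encoder(images)
    x_hat = model.decoder(z)
    out = model.task_net(x_hat)
    if dataset_id in [3, 4, 5, 6]:  # the segmentation task network outputs a dict
        out = out["out"]

    t_error = task_loss(out, labels)
    AE_error = r_loss_w*recon_loss(x_hat, images) + t_loss_w*t_error

    task_optimizer.zero_grad()
    AE_optimizer.zero_grad()

    # Accumulate task-loss gradients only into the task net's parameters
    t_error.backward(retain_graph=True, inputs=list(model.task_net.parameters()))
    # Accumulate the combined-loss gradients only into the encoder/decoder parameters
    AE_error.backward(inputs=list(model.encoder.parameters()) + list(model.decoder.parameters()))

    task_optimizer.step()
    AE_optimizer.step()

Does that look valid, or am I missing something there as well?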

Could someone please explain why these two training loops behave differently, and why the first one is not doing what I intended? I have been trying to get my head around it, but I can't seem to figure it out.

Thanks a lot in advance!!