Hi all,
I am training a model composed of a convolutional AE (meant to learn to compress and decompress the input images) and TIMM’s resnet50_FCN, which takes the AE output and segments it.
In my training setup I have two cases: either the AE is trained using only the reconstruction error, or it is trained using a combination (weighted sum) of the reconstruction error and the task error. For this I have two optimizers:
task_optimizer = torch.optim.Adam(model.task_net.parameters(), lr=args.lr)
AE_optimizer = torch.optim.Adam(list(model.encoder.parameters()) + list(model.decoder.parameters()), lr=args.lr)
And this is my training loop:
# Get images and labels and send them to device
images, labels = batch[0].to(device, non_blocking=True), batch[1].to(device, non_blocking=True)
labels = labels.squeeze()
# Zero the gradients of both optimizers
task_optimizer.zero_grad()
AE_optimizer.zero_grad()
# Forward pass
z = model.encoder(images)
x_hat = model.decoder(z)
out = model.task_net(x_hat)
if dataset_id in [3, 4, 5, 6]: # the segmentation task network outputs a dict
out = out["out"]
# Calculate loss, backprop and update weights
t_error = task_loss(out, labels) # task loss
r_error = recon_loss(x_hat, images) # reconstruction loss
AE_error = r_loss_w*r_error + t_loss_w*t_error # AE trained with a combination of losses
t_error.backward(retain_graph=True) # Don't delete graph for next backward pass
AE_error.backward(retain_graph=False)
task_optimizer.step()
AE_optimizer.step()
As you can see in the snippet above, when only the reconstruction error is used, t_loss_w is set to 0. However, the two setups produced nearly identical performance. After trying several things, my supervisor checked the code and said that, even though I multiply the task error by 0 and the combined loss is therefore just the reconstruction loss, the AE is still being updated with gradients coming from the task net (or something along those lines).
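The supervisor's explanation can be checked on a toy problem. The sketch below is a minimal illustration, assuming tiny `nn.Linear` layers as hypothetical stand-ins for `model.encoder`, `model.decoder` and `model.task_net`, with cross-entropy and MSE standing in for `task_loss` and `recon_loss`. It replays the first loop with `t_loss_w = 0`. Because `x_hat` is not detached, `t_error.backward()` reaches the encoder and decoder, and `backward()` adds into each parameter's `.grad` rather than overwriting it. The AE gradient therefore ends up as the reconstruction gradient plus the full task gradient. The same accumulation gives the task net `(1 + t_loss_w)` times its task gradient whenever `t_loss_w > 0`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy stand-ins for model.encoder, model.decoder and model.task_net.
encoder, decoder, task_net = nn.Linear(4, 3), nn.Linear(3, 4), nn.Linear(4, 2)
images = torch.randn(8, 4)
labels = torch.randint(0, 2, (8,))
r_loss_w, t_loss_w = 1.0, 0.0  # reconstruction-only setting


def losses():
    """One forward pass through AE + task net; returns (task loss, reconstruction loss)."""
    x_hat = decoder(encoder(images))
    out = task_net(x_hat)
    return F.cross_entropy(out, labels), F.mse_loss(x_hat, images)


# Reference: gradient of the reconstruction loss alone w.r.t. one AE weight.
# torch.autograd.grad returns the gradient without touching .grad.
_, r_error = losses()
ref_grad = torch.autograd.grad(r_loss_w * r_error, encoder.weight)[0]

# Replay the first training loop (all .grad are still None at this point).
t_error, r_error = losses()
AE_error = r_loss_w * r_error + t_loss_w * t_error
t_error.backward(retain_graph=True)  # also writes task-loss gradients into the AE's .grad
AE_error.backward()                  # accumulates on top of them instead of replacing them
loop_grad = encoder.weight.grad.clone()

# The task-loss gradient that leaked into the AE.
t_error, _ = losses()
task_grad = torch.autograd.grad(t_error, encoder.weight)[0]

print(torch.allclose(loop_grad, ref_grad))              # False
print(torch.allclose(loop_grad, ref_grad + task_grad))  # True
```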
We then changed it to the following code, which is very inefficient but is supposed to work:
# Get images and labels and send them to device
images, labels = batch[0].to(device, non_blocking=True), batch[1].to(device, non_blocking=True)
labels = labels.squeeze()
# Forward pass
z = model.encoder(images)
x_hat = model.decoder(z)
out = model.task_net(x_hat)
if dataset_id in [3, 4, 5, 6]: # the segmentation task network outputs a dict
out = out["out"]
# Do updates sequentially so that the non-task-aware AE gets updated without residual gradients from the task loss
# Update TASK NET
# Calculate loss
t_error = task_loss(out, labels) # task loss
# clear gradients
task_optimizer.zero_grad()
# backward
t_error.backward()
# update parameters
task_optimizer.step()
# Forward pass again (this is very inefficient, but it does not work otherwise?)
z = model.encoder(images)
x_hat = model.decoder(z)
out = model.task_net(x_hat)
if dataset_id in [3, 4, 5, 6]: # the segmentation task network outputs a dict
out = out["out"]
# Update AE
# Calculate loss
r_error = recon_loss(x_hat, images) # reconstruction loss
AE_error = t_loss_w*task_loss(out, labels)
AE_error += r_loss_w*r_error # AE trained with a combination of losses
# clear gradients
AE_optimizer.zero_grad()
# backward
AE_error.backward()
# update parameters
AE_optimizer.step()
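One way to keep a single forward pass is to restrict each `backward()` call to the parameters its optimizer owns, using the `inputs=` argument of `Tensor.backward` (available since PyTorch 1.8). The sketch below is an assumption-laden illustration, not a tested drop-in for the real model: the `nn.Linear` layers are hypothetical stand-ins, the losses are cross-entropy and MSE, and the two `print` lines only check that the AE now sees the reconstruction gradient alone when `t_loss_w = 0`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy stand-ins for model.encoder, model.decoder and model.task_net.
encoder, decoder, task_net = nn.Linear(4, 3), nn.Linear(3, 4), nn.Linear(4, 2)
task_params = list(task_net.parameters())
ae_params = list(encoder.parameters()) + list(decoder.parameters())
task_optimizer = torch.optim.Adam(task_params, lr=1e-3)
AE_optimizer = torch.optim.Adam(ae_params, lr=1e-3)

images = torch.randn(8, 4)
labels = torch.randint(0, 2, (8,))
r_loss_w, t_loss_w = 1.0, 0.0  # reconstruction-only setting

task_optimizer.zero_grad()
AE_optimizer.zero_grad()

# A single forward pass shared by both losses.
x_hat = decoder(encoder(images))
out = task_net(x_hat)
t_error = F.cross_entropy(out, labels)
r_error = F.mse_loss(x_hat, images)
AE_error = r_loss_w * r_error + t_loss_w * t_error

# Reference gradients, computed only to check the result below.
ref_ae = torch.autograd.grad(r_loss_w * r_error, encoder.weight, retain_graph=True)[0]
ref_task = torch.autograd.grad(t_error, task_net.weight, retain_graph=True)[0]

# Each backward pass only accumulates into the parameters its optimizer owns.
t_error.backward(inputs=task_params, retain_graph=True)
AE_error.backward(inputs=ae_params)

print(torch.allclose(encoder.weight.grad, ref_ae))     # True: no task-loss leak
print(torch.allclose(task_net.weight.grad, ref_task))  # True

# Step only after both backward passes, so the graph is never used after an in-place update.
task_optimizer.step()
AE_optimizer.step()
```

Because both `step()` calls come after both `backward()` calls, no weight is modified before the graph is used. In the sequential version, `task_optimizer.step()` updates the task net's weights in place after the first graph has already been consumed by `t_error.backward()`, which is most likely why reusing that graph did not work without a second forward pass. In the reconstruction-only case, feeding `x_hat.detach()` to the task net would be another way to keep task gradients out of the AE.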
Could someone please explain why these two cases behave differently and why the first training loop does not do what I intended? I have been trying to get my head around it, but I can’t seem to figure it out.
Thanks a lot in advance!!