Multi-task training without summing losses

Hello,

I am trying to implement an adaptation of the classic RetinaNet object detector. RetinaNet can be considered a multi-task network with classification and localization as separate tasks: it shares a common backbone with two heads, one for localization and one for classification.

The usual way to train this network is to compute separate losses L_class and L_loc for the two tasks and sum them:
L_total = L_class + L_loc.

The training procedure is then:

optimizer.zero_grad()
L_class, L_loc = loss(retinanet(x))
L_total = L_class + L_loc
L_total.backward()
optimizer.step()


I would like to have two separate optimizers, one for the classification task and one for the localization task. My question is how to do this correctly, so that both heads are trained by their own optimizers (opt_class, opt_loc). I would propose the following in my training loop (how I would construct the two optimizers is sketched after the snippet):

# Localization first
opt_loc.zero_grad()
_, L_loc = loss(retinanet(x))
L_loc.backward()
opt_loc.step()

# Classification second
opt_class.zero_grad()
L_class, _ = loss(retinanet(x))
L_class.backward()
opt_class.step()
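
For completeness, this is roughly how I would construct the two optimizers. I am assuming torchvision-style submodule names here (backbone, head.regression_head, head.classification_head); the actual attribute names depend on the implementation. Both optimizers include the shared backbone, so the backbone is updated by both tasks:

import itertools
import torch

# Assumed torchvision-style submodule names -- adjust to the actual model.
# Each optimizer holds the shared backbone plus one task-specific head.
opt_loc = torch.optim.SGD(
    itertools.chain(retinanet.backbone.parameters(),
                    retinanet.head.regression_head.parameters()),
    lr=0.01, momentum=0.9)
opt_class = torch.optim.SGD(
    itertools.chain(retinanet.backbone.parameters(),
                    retinanet.head.classification_head.parameters()),
    lr=0.01, momentum=0.9)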

Would this be a sound approach, and is it possible to do this without additional GPU memory usage?
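
For reference, since my proposal runs the forward pass twice, I also considered computing both losses from a single forward pass and calling backward() twice, with retain_graph=True on the first call so the graph survives for the second backward. This is only a rough sketch, and I am aware it is not identical to my two-pass version, because the shared backbone accumulates gradients from both losses before either optimizer steps:

opt_loc.zero_grad()
opt_class.zero_grad()
L_class, L_loc = loss(retinanet(x))
# retain_graph=True keeps the saved activations alive for the second backward
L_loc.backward(retain_graph=True)
L_class.backward()
# the shared backbone now holds the sum of both gradients and is stepped twice
opt_loc.step()
opt_class.step()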