Multiple sub-task losses: backward for a specific part of the model

My model has three sub-task losses. After a single forward pass, I want to compute gradients with a separate loss.backward() call for each loss, and then call optimizer.step() on different parts of the model. The model is defined below:

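import torch
import torch.nn as nn
from torchvision.models import resnet50

class SomeNet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.encoder = resnet50(pretrained=True)
        # resnet50 ends in a 1000-way fc layer; replace it with Identity so
        # the encoder returns the 2048-d pooled feature the heads expect
        self.encoder.fc = nn.Identity()
        self.cls1 = nn.Linear(2048, num_classes)
        self.cls2 = nn.Linear(2048, num_classes)

    def forward(self, x):
        feature = self.encoder(x)
        output1 = self.cls1(feature)
        output2 = self.cls2(feature)
        return output1, output2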

Currently I call optimizer.zero_grad() before each loss.backward() to ensure that only a specific part is updated. Concretely, I update encoder and cls1 with loss1, cls2 with loss2, and encoder with loss3. See the demo code below:

model = SomeNet(num_classes)
# SGD expects an iterable of parameters (or a list of param-group dicts);
# the lr values here are placeholders
optimizer_fea = torch.optim.SGD(model.encoder.parameters(), lr=0.01)
optimizer_cls1 = torch.optim.SGD(model.cls1.parameters(), lr=0.01)
optimizer_cls2 = torch.optim.SGD(model.cls2.parameters(), lr=0.01)

for data, label1, label2, label3 in data_loader:
    output1, output2 = model(data)
    loss1 = criterion1(output1, label1)
    loss2 = criterion2(output2, label2)
    loss3 = criterion2(output2, label3)

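Since the three optimizers differ only in the parameters they hold, they could be collapsed into one optimizer with parameter groups. The sketch below uses a small hypothetical `TinyNet` as a stand-in for `SomeNet`, and the learning rates are placeholders. It fits best when the gradients for all three parts are accumulated before a single `step()`, because `optimizer.zero_grad()` and `optimizer.step()` then act on every group at once.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for SomeNet: same three parts, tiny sizes."""

    def __init__(self, num_classes=3):
        super().__init__()
        self.encoder = nn.Linear(8, 16)
        self.cls1 = nn.Linear(16, num_classes)
        self.cls2 = nn.Linear(16, num_classes)

    def forward(self, x):
        feature = self.encoder(x)
        return self.cls1(feature), self.cls2(feature)


model = TinyNet()
# One optimizer with one parameter group per part; a group-level "lr"
# overrides the default passed as the keyword argument.
optimizer = torch.optim.SGD(
    [
        {"params": model.encoder.parameters(), "lr": 0.001},
        {"params": model.cls1.parameters()},
        {"params": model.cls2.parameters()},
    ],
    lr=0.01,  # default for groups that don't set their own lr
)
```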
I have two questions about the code above:

  1. If I call loss3.backward() without retain_graph=True, most of the computational graph is freed. However, loss3 does not involve the cls1 node of the graph, so I suspect part of the graph is not freed. Is there a convenient way to free the entire computational graph after calling loss3.backward()?

  2. The code above is tediously long. Is there a convenient way to compute gradients only for a specific layer of the model?
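One possible way to restrict a backward pass to specific parameters is the `inputs` argument of `Tensor.backward()` (available since PyTorch 1.8): gradients are accumulated into `.grad` only for the tensors listed there. The sketch below is an illustration, not a tested recipe for `SomeNet`; `TinyNet`, the helper `backward_per_part`, and the cross-entropy losses are hypothetical stand-ins. Folding loss1 and loss3 into the last call means that call does not need `retain_graph`, so the saved buffers are released after it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Hypothetical stand-in for SomeNet with the same three parts."""

    def __init__(self, num_classes=3):
        super().__init__()
        self.encoder = nn.Linear(8, 16)
        self.cls1 = nn.Linear(16, num_classes)
        self.cls2 = nn.Linear(16, num_classes)

    def forward(self, x):
        feature = self.encoder(x)
        return self.cls1(feature), self.cls2(feature)


def backward_per_part(model, loss1, loss2, loss3):
    """Accumulate each loss's gradient only into the parameters it should update."""
    enc = list(model.encoder.parameters())
    cls1 = list(model.cls1.parameters())
    cls2 = list(model.cls2.parameters())
    # loss2 -> cls2 only; keep the graph because loss3 shares it
    loss2.backward(inputs=cls2, retain_graph=True)
    # loss1 -> encoder + cls1, loss3 -> encoder only (cls2 is not in inputs).
    # No retain_graph here, so the saved buffers are freed afterwards.
    (loss1 + loss3).backward(inputs=enc + cls1)


torch.manual_seed(0)
model = TinyNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # placeholder lr
x = torch.randn(4, 8)
y1, y2, y3 = (torch.randint(0, 3, (4,)) for _ in range(3))

optimizer.zero_grad()
out1, out2 = model(x)
loss1 = F.cross_entropy(out1, y1)
loss2 = F.cross_entropy(out2, y2)
loss3 = F.cross_entropy(out2, y3)
backward_per_part(model, loss1, loss2, loss3)
optimizer.step()
```

With one optimizer over all parameters, a single `zero_grad()` and `step()` per iteration replaces the three separate optimizers.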

Seriously, does no one know how to do this?

I want to use loss1 to update model.encoder and model.cls1, which is simple: I can just call loss1.backward().
But loss2 and loss3 are different. They are computed on exactly the same computational graph, yet I want loss2 to update only model.cls2 and loss3 to update only model.encoder.

Can I do that in an elegant way?
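One alternative that avoids `retain_graph` altogether is to cut the graph where each loss should stop: feed `feature.detach()` into `cls2` for loss2, and apply the cls2 weights detached (via `F.linear`) for loss3, so that a single `backward()` on the summed loss routes each term only to the intended parameters. This is a sketch under assumptions: it changes `forward` to return a third output (a hypothetical modification), it runs the small linear head twice, and `TinyNet` with cross-entropy losses stands in for the real model and criteria.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Hypothetical stand-in for SomeNet with the same three parts."""

    def __init__(self, num_classes=3):
        super().__init__()
        self.encoder = nn.Linear(8, 16)
        self.cls1 = nn.Linear(16, num_classes)
        self.cls2 = nn.Linear(16, num_classes)

    def forward(self, x):
        feature = self.encoder(x)
        output1 = self.cls1(feature)           # loss1: encoder + cls1
        output2 = self.cls2(feature.detach())  # loss2: cls2 only
        # loss3: encoder only; same head, but its weights are detached
        output3 = F.linear(feature, self.cls2.weight.detach(), self.cls2.bias.detach())
        return output1, output2, output3


torch.manual_seed(0)
model = TinyNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # placeholder lr
x = torch.randn(4, 8)
y1, y2, y3 = (torch.randint(0, 3, (4,)) for _ in range(3))

optimizer.zero_grad()
out1, out2, out3 = model(x)
loss = (F.cross_entropy(out1, y1)
        + F.cross_entropy(out2, y2)
        + F.cross_entropy(out3, y3))
loss.backward()  # one call without retain_graph: saved buffers are freed afterwards
optimizer.step()
```

Because output2 and output3 have identical values, loss3 sees the same predictions as before; only the paths the gradients can take differ.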

I think torch.autograd.grad may do the job. But I’m not sure how to implement it.
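A minimal sketch of the `torch.autograd.grad` route, assuming the same stand-in `TinyNet` and cross-entropy losses (both hypothetical): `torch.autograd.grad` returns gradients only for the tensors passed as `inputs` without touching `.grad`, so each part's gradient can be computed from the loss that should drive it and then written into `.grad` before one `optimizer.step()`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Hypothetical stand-in for SomeNet with the same three parts."""

    def __init__(self, num_classes=3):
        super().__init__()
        self.encoder = nn.Linear(8, 16)
        self.cls1 = nn.Linear(16, num_classes)
        self.cls2 = nn.Linear(16, num_classes)

    def forward(self, x):
        feature = self.encoder(x)
        return self.cls1(feature), self.cls2(feature)


def set_subtask_grads(model, loss1, loss2, loss3):
    """Compute each part's gradient with torch.autograd.grad and store it in .grad."""
    enc = list(model.encoder.parameters())
    cls1 = list(model.cls1.parameters())
    cls2 = list(model.cls2.parameters())
    # loss2 w.r.t. cls2 only; retain the graph because loss3 reuses it
    g_cls2 = torch.autograd.grad(loss2, cls2, retain_graph=True)
    # loss1 reaches encoder + cls1, loss3 reaches the encoder; last use of the graph
    g_enc_cls1 = torch.autograd.grad(loss1 + loss3, enc + cls1)
    # Every parameter is assigned, so stale .grad values are overwritten
    for p, g in zip(cls2 + enc + cls1, g_cls2 + g_enc_cls1):
        p.grad = g


torch.manual_seed(0)
model = TinyNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # placeholder lr
x = torch.randn(4, 8)
y1, y2, y3 = (torch.randint(0, 3, (4,)) for _ in range(3))

out1, out2 = model(x)
set_subtask_grads(model,
                  F.cross_entropy(out1, y1),
                  F.cross_entropy(out2, y2),
                  F.cross_entropy(out2, y3))
optimizer.step()
```

Compared with calling `backward()` several times, this keeps the routing explicit in one place, at the cost of writing `.grad` by hand.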