Multi-task learning with adaptive weights for task losses

I am trying to reproduce this recent paper:
GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks

The idea is to normalize gradient magnitudes across tasks, and the authors use this to adaptively learn a weight for each task's loss during training.
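For context, the gradient loss in Eq. 2 of the paper (as I read it) is

L_grad = sum_i | G_W^(i) - G_avg_W * (r_i)^alpha |_1

where G_W^(i) is the L2 norm of the gradient of the weighted task-i loss with respect to the chosen shared weights W, G_avg_W is the average of these norms over the tasks, and r_i is the relative inverse training rate of task i.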
My main model class for two arbitrary regression tasks (one shared layer and two task-specific towers) is as follows:

import torch
import torch.nn as nn

class MTLnet(nn.Module):
    def __init__(self):
        super(MTLnet, self).__init__()
        # shared trunk, followed by one tower per task
        self.sharedlayer = nn.Sequential(
            nn.Linear(feature_size, shared_layer_size),
            nn.ReLU(),
            nn.Dropout()
        )
        self.tower1 = nn.Sequential(
            nn.Linear(shared_layer_size, tower_h1),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(tower_h1, tower_h2),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(tower_h2, output_size)
        )
        self.tower2 = nn.Sequential(
            nn.Linear(shared_layer_size, tower_h1),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(tower_h1, tower_h2),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(tower_h2, output_size)
        )        

    def forward(self, x):
        h_shared = self.sharedlayer(x)
        out1 = self.tower1(h_shared)
        out2 = self.tower2(h_shared)
        return out1, out2

MTL = MTLnet()
opt1 = torch.optim.Adam(MTL.parameters(), lr=LR)
loss_func = nn.MSELoss()
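The snippet assumes feature_size, shared_layer_size, tower_h1, tower_h2, output_size, and LR are defined elsewhere; their exact values don't matter for the question, but for a self-contained toy run you could pick placeholders like:

# placeholder hyperparameters, only to make the snippet runnable
feature_size = 10
shared_layer_size = 64
tower_h1 = 32
tower_h2 = 16
output_size = 1   # one regression output per task
LR = 1e-3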

And two trainable weights, one per task loss:

Weightloss1 = torch.tensor([1.0], requires_grad=True)
Weightloss2 = torch.tensor([1.0], requires_grad=True)
params = [Weightloss1, Weightloss2]
opt2 = torch.optim.Adam(params, lr=LR)
Gradloss = nn.L1Loss()  # the |.|_1 in Eq. 2
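(As a quick sanity check with the definitions above, both weights should be trainable leaf tensors that opt2 can update:)

assert Weightloss1.is_leaf and Weightloss1.requires_grad
assert Weightloss2.is_leaf and Weightloss2.requires_grad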

And the training loop over mini-batches goes like this:

alph = 0.16  # the alpha hyperparameter from the paper

for epoch in range(epochs):  # assuming an outer epoch loop; 'epochs' defined elsewhere
    for minibatch in minibatches:
        XE, YE1, YE2 = minibatch

        # Weighted losses for the two tasks
        Yhat1, Yhat2 = MTL(XE)
        l1 = params[0].data * loss_func(Yhat1, YE1.view(-1, 1))
        l2 = params[1].data * loss_func(Yhat2, YE2.view(-1, 1))
        loss = (l1 + l2) / 2

        # Capture the initial loss l0 in the first epoch
        if epoch == 0:
            l0 = loss.data

        opt1.zero_grad()
        opt2.zero_grad()

        # Retain the graph: Lgrad.backward() below runs through it again
        loss.backward(retain_graph=True)

        # Gradients of the first layer in each tower and their L2 norms
        # (par[2] / par[8] are the weights of the first Linear in tower1 / tower2)
        par = list(MTL.parameters())
        G1 = torch.tensor(par[2].grad.norm(2), requires_grad=True)
        G2 = torch.tensor(par[8].grad.norm(2), requires_grad=True)
        G_avg = (G1 + G2) / 2

        # Relative losses
        lhat1 = l1 / l0
        lhat2 = l2 / l0
        lhat_avg = (lhat1 + lhat2) / 2

        # Relative inverse training rates for the tasks
        inv_rate1 = lhat1 / lhat_avg
        inv_rate2 = lhat2 / lhat_avg

        # Gradient loss according to Eq. 2 in the GradNorm paper
        Lgrad = Gradloss(G1, G_avg * inv_rate1 ** alph) + Gradloss(G2, G_avg * inv_rate2 ** alph)
        Lgrad.backward()

        # Update the loss weights
        opt2.step()
        # Update the model weights
        opt1.step()

        # Renormalize the loss weights so they sum to 2
        params = [2 * params[0].data / (params[0].data + params[1].data),
                  2 * params[1].data / (params[0].data + params[1].data)]

The problem is that although Weightloss1 and Weightloss2 have requires_grad=True, and G1 and G2 have gradients (checked via .grad), G_avg and Lgrad have None for their gradients, so the loss weights in params never change.
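For reference, this is roughly how I checked it after Lgrad.backward():

print(G1.grad, G2.grad)        # populated
print(G_avg.grad, Lgrad.grad)  # None (non-leaf tensors don't retain .grad)
print(Weightloss1.grad)        # also None, so opt2.step() has nothing to apply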


Interested! Have you found a solution?


Take a look at my GitHub repository here: https://github.com/hosseinshn/GradNorm


Have you solved this?
By the way, I noticed that some details in your code differ from the paper. In the paper, l0 holds the initial loss of each task separately, but in your code, l0 is the averaged initial loss.
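In other words, something along these lines, using the variable names from your snippet:

# one initial loss per task instead of a single averaged l0
if epoch == 0:
    l01 = l1.data
    l02 = l2.data

# the relative losses then become per-task
lhat1 = l1 / l01
lhat2 = l2 / l02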


Thanks for pointing out the error, and sorry for the delayed response. I added a new notebook (v10) and updated the handling of l0.
