Weights are not being updated

Hi, I am implementing a toy example of multi-task learning with different weights for each task’s loss; my main model class looks like this:

import torch
import torch.nn as nn
from torch.autograd import Variable

class MTLnet(nn.Module):
    def __init__(self):
        super(MTLnet, self).__init__()
        self.weightloss1 = Variable(torch.FloatTensor([1]), requires_grad=True)
        self.weightloss2 = Variable(torch.FloatTensor([1]), requires_grad=True)
        self.sharedlayer = nn.Sequential(
            nn.Linear(feature_size, shared_layer_size),
            nn.ReLU(),
            nn.Dropout()
        )
        self.tower1 = nn.Sequential(
            nn.Linear(shared_layer_size, tower_h1),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(tower_h1, tower_h2),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(tower_h2, output_size)
        )
        self.tower2 = nn.Sequential(
            nn.Linear(shared_layer_size, tower_h1),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(tower_h1, tower_h2),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(tower_h2, output_size)
        )        

    def forward(self, x):
        h_shared = self.sharedlayer(x)
        out1 = self.tower1(h_shared)
        out2 = self.tower2(h_shared)
        return out1, out2

However, those two weights do not show up among the model's parameters, so they are never updated. Printing the model confirms that only the layers are registered:

MTLnet(
  (sharedlayer): Sequential(
    (0): Linear(in_features=100, out_features=64, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5)
  )
  (tower1): Sequential(
    (0): Linear(in_features=64, out_features=32, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5)
    (3): Linear(in_features=32, out_features=16, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.5)
    (6): Linear(in_features=16, out_features=1, bias=True)
  )
  (tower2): Sequential(
    (0): Linear(in_features=64, out_features=32, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5)
    (3): Linear(in_features=32, out_features=16, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.5)
    (6): Linear(in_features=16, out_features=1, bias=True)
  )
)

Any thoughts?

Thanks.

The loss weights are not registered as parameters, because plain tensors (or Variables) assigned to a module are not tracked by it.
Wrap them in nn.Parameter and it should work:

class MTLnet(nn.Module):
    def __init__(self):
        super(MTLnet, self).__init__()
        self.weightloss1 = torch.nn.Parameter(torch.FloatTensor([1]))
        self.weightloss2 = torch.nn.Parameter(torch.FloatTensor([1]))
        ...

Also, the usage of Variable is deprecated, since Variables and tensors were merged a while ago (in PyTorch 0.4).
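
In case it helps, here is a minimal sketch of how the wrapped weights then behave; the batch tensors x, y1, y2 and the MSE criterion are just assumptions for illustration:

model = MTLnet()

# Both loss weights now show up next to the layer parameters:
print([name for name, _ in model.named_parameters()])
# ['weightloss1', 'weightloss2', 'sharedlayer.0.weight', ...]

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()  # assumption: both tasks are regression

out1, out2 = model(x)  # x, y1, y2: a training batch (placeholders)
loss = model.weightloss1 * criterion(out1, y1) + model.weightloss2 * criterion(out2, y2)

optimizer.zero_grad()
loss.backward()   # gradients now also reach weightloss1 / weightloss2
optimizer.step()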

PS: I’ve formatted your code for readability. You can add code snippets by wrapping them in three backticks :wink:

Thanks a lot! It worked.

Hello Hossein,

Does “multi-task learning with different weights for each task’s loss” actually help MTL performance? I’m also curious about how the weights can be learned. I have been studying MTL recently.

Thanks

Hi Kai, in general I’d say yes, but like many other aspects of ML and MTL it depends on the problem. I implemented a paper on weighted-task MTL here: https://github.com/hosseinshn/GradNorm
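
One common way the weights themselves can be learned (different from the GradNorm approach in the repo above) is uncertainty-based weighting (Kendall et al., 2018). A rough sketch, where the class name and num_tasks argument are my own placeholders:

import torch
import torch.nn as nn

class UncertaintyWeighting(nn.Module):
    # One learnable log-variance s_i per task; the combined loss is
    # sum_i exp(-s_i) * loss_i + s_i, which keeps the weights from
    # collapsing to zero.
    def __init__(self, num_tasks=2):
        super().__init__()
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, *task_losses):
        total = 0.0
        for s, task_loss in zip(self.log_vars, task_losses):
            total = total + torch.exp(-s) * task_loss + s
        return total

# Usage: optimize the weighting jointly with the model
# weighting = UncertaintyWeighting(num_tasks=2)
# optimizer = torch.optim.Adam(list(model.parameters()) + list(weighting.parameters()), lr=1e-3)
# loss = weighting(loss1, loss2)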