I have an MLP and I'm trying to predict 2 values with it (so I'm doing multi-task learning). These two values have losses of very different scales (one is in the tens and the other is in the 100,000s), so I can't simply add the losses together before backpropagating. I'm currently trying to learn a weight that balances the two losses, but I don't know how to go about that. Below is a brief summary of the code:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, 100)
        self.out1 = nn.Linear(100, output_dim - 1)
        self.out2 = nn.Linear(100, 1)
        self.relu = nn.ReLU()
        # create a learnable parameter that is used to weight the two losses
        # (nn.Parameter already sets requires_grad=True)
        self.weight = nn.Parameter(torch.ones(1))

    def forward(self, x):
        # shared hidden layer, computed once and fed to both heads
        hidden = self.relu(self.fc1(x))
        out1 = self.out1(hidden)
        out2 = self.out2(hidden)
        return out1, out2, self.weight


# in train method (assume we have the target values and data)
predicted_one, predicted_two, weight = model(data)
loss_one = F.mse_loss(predicted_one, target_one)
loss_two = F.mse_loss(predicted_two, target_two)
total_loss = loss_one + weight * loss_two
```
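One way to see what the optimizer actually does with `weight` in a loss of the form `loss_one + weight * loss_two` is to look at its gradient in isolation. The sketch below is a minimal check under stated assumptions: the two loss values are hypothetical fixed stand-ins chosen only to match the scales described, and plain SGD with an arbitrary learning rate stands in for whatever optimizer is used in training. Since the gradient with respect to `weight` is exactly `loss_two`, which is never negative, every step lowers `weight`, and nothing stops it from going below zero.

```python
import torch

# Hypothetical stand-in loss values on the scales described in the question.
loss_one = torch.tensor(25.0)
loss_two = torch.tensor(150_000.0)

# Same kind of learnable scalar as self.weight in the model above.
weight = torch.nn.Parameter(torch.ones(1))
optimizer = torch.optim.SGD([weight], lr=1e-6)  # arbitrary learning rate

history = []
for _ in range(10):
    optimizer.zero_grad()
    total_loss = loss_one + weight * loss_two
    total_loss.backward()
    # d(total_loss)/d(weight) is exactly loss_two, which is never negative,
    # so each gradient-descent step can only push weight down.
    assert weight.grad.item() == loss_two.item()
    optimizer.step()
    history.append(weight.item())

# Once weight < 0, the combined "loss" drops below zero, something a sum of
# MSE terms with positive weights can never do.
final_total = (loss_one + weight * loss_two).item()
print(history)
print(final_total)
```

In other words, minimizing this objective rewards shrinking (and then flipping the sign of) the second task's contribution rather than balancing the two tasks.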
I don't think this is the right way to go about it, because the weight doesn't learn. Please kindly assist.
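A commonly cited alternative to a raw learnable weight is the uncertainty weighting of Kendall, Gal and Cipolla (2018, "Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics"), which learns a log-variance s_i per task and minimizes sum_i exp(-s_i) * L_i + s_i. For regression losses this is half of the paper's objective, so it has the same minimizer. The sketch below is a simplified illustration, not the paper's code: the class name `UncertaintyWeightedLoss` is hypothetical and the two loss values are fixed stand-ins. It checks two properties: at initialization the gradient raises each s_i (lowering the weight exp(-s_i)), and the gradient vanishes at s_i = log(L_i). Because the objective is convex in each s_i, that point is the minimum, so the learned weights settle around 1/L_i instead of collapsing.

```python
import math

import torch
import torch.nn as nn


class UncertaintyWeightedLoss(nn.Module):
    """Hypothetical helper: weights task losses by learned log-variances.

    total = sum_i exp(-s_i) * L_i + s_i, where s_i = log(sigma_i^2).
    The "+ s_i" term penalizes large variances, so unlike a raw learnable
    weight, the optimizer cannot drive a task's weight to zero for free.
    """

    def __init__(self, num_tasks: int):
        super().__init__()
        # One log-variance per task; s_i = 0 means an initial weight of 1.
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, *losses: torch.Tensor) -> torch.Tensor:
        stacked = torch.stack(losses)
        return (torch.exp(-self.log_vars) * stacked + self.log_vars).sum()


# Hypothetical stand-in losses on the two scales from the question.
loss_one = torch.tensor(25.0)
loss_two = torch.tensor(150_000.0)
criterion = UncertaintyWeightedLoss(num_tasks=2)

# At s_i = 0 the gradient is 1 - L_i (here -24 and -149999), so gradient
# descent raises s_i, i.e. lowers the weight exp(-s_i), with a much larger
# push on the big loss.
criterion(loss_one, loss_two).backward()
init_grad = criterion.log_vars.grad.clone()

# d/ds [exp(-s) * L + s] = 1 - exp(-s) * L, which is zero at s = log(L):
# there each weighted term exp(-s_i) * L_i equals 1.
with torch.no_grad():
    criterion.log_vars.copy_(torch.tensor([math.log(25.0), math.log(150_000.0)]))
criterion.log_vars.grad = None
criterion(loss_one, loss_two).backward()
optimum_grad = criterion.log_vars.grad.clone()
effective_weights = torch.exp(-criterion.log_vars.detach())

print(init_grad)
print(optimum_grad)  # close to zero, up to float32 rounding
print(effective_weights)  # close to 1/25 and 1/150000
```

To use something like this in the training loop, `total_loss = criterion(loss_one, loss_two)` would replace the hand-built sum (making `self.weight` in the model unnecessary), and the criterion's parameters would need to go to the optimizer alongside the model's, e.g. `torch.optim.Adam(list(model.parameters()) + list(criterion.parameters()))`. With one loss in the 100,000s the early gradients on s_i are large, so an adaptive optimizer or a small learning rate for these parameters may be needed.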