Weighted loss during ensemble

I am ensembling two models with mean pooling, but I also want to weight the loss of each separate model at the same time, so that the less accurate model contributes less to the final prediction.

Here is a simple example of what I am trying to achieve.

import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightLoss(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.n = n
        self.param = nn.Parameter(torch.empty(n), requires_grad=True)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.constant_(self.param, 1 / self.n)

    def forward(self, x):
        # make sure the params sum to 1
        return torch.mul(x, F.softmax(self.param))


class Net2(nn.Module):
    def __init__(self, modelA, modelB):
        super(Net2, self).__init__()
        self.modelA = modelA
        self.modelB = modelB

        self.fc1 = nn.Linear(10, 8)
        self.fc2 = nn.Linear(10, 8)

        self.weight_loss = WeightLoss(2)
    
    def loss(self, x, y):
        return F.cross_entropy(x, y)
    
    def forward(self, x, y):
        # individual outputs
        output1 = self.fc1(self.modelA(x))
        output2 = self.fc2(self.modelB(x))
        # mean pooling
        final_output = torch.mean(torch.stack([output1, output2]), dim=0)
        final_loss = self.loss(final_output, y)
        # weighted loss for individual models
        loss1 = self.loss(output1, y)
        loss2 = self.loss(output2, y)
        weighted_loss = self.weight_loss(torch.stack([loss1, loss2]))
        weighted_loss = torch.sum(weighted_loss)

        # combine the 2 losses
        final_loss = torch.add(final_loss, weighted_loss)

        return final_loss

With the above architecture, final_loss does decrease over time, but the weights for weighing the losses of the two models either do not change over time or change only by a very small margin.

Have I done anything wrong?

Your code seems to work in general and reduces the weight for the larger loss in order to decrease the overall global loss. Note that the gradient on self.param is proportional to the difference between the individual losses, so the weights will only move slowly once loss1 and loss2 are close to each other:

import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightLoss(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.n = n
        self.param = nn.Parameter(torch.empty(n), requires_grad=True)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.constant_(self.param, 1 / self.n)

    def forward(self, x):
        # make sure the params sum to 1
        return torch.mul(x, F.softmax(self.param, dim=0))


class Net2(nn.Module):
    def __init__(self, modelA, modelB):
        super(Net2, self).__init__()
        self.modelA = modelA
        self.modelB = modelB

        self.fc1 = nn.Linear(10, 8)
        self.fc2 = nn.Linear(10, 8)

        self.weight_loss = WeightLoss(2)
    
    def loss(self, x, y):
        return F.cross_entropy(x, y)
    
    def forward(self, x, y):
        # individual outputs
        output1 = self.fc1(self.modelA(x))
        output2 = self.fc2(self.modelB(x))
        # mean pooling
        final_output = torch.mean(torch.stack([output1, output2]), dim=0)
        final_loss = self.loss(final_output, y)
        # weighted loss for individual models
        loss1 = self.loss(output1, y)
        loss2 = self.loss(output2, y)
        print("loss1 {}, loss2 {}".format(loss1.item(), loss2.item()))
        weighted_loss = self.weight_loss(torch.stack([loss1, loss2]))
        weighted_loss = torch.sum(weighted_loss)

        # combine the 2 losses
        final_loss = torch.add(final_loss, weighted_loss)

        return final_loss
    
model = Net2(nn.Linear(10, 10), nn.Linear(10, 10))
x = torch.randn(8, 10)
target = torch.randint(0, 8, (8,))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(1000):
    optimizer.zero_grad()
    loss = model(x, target)
    loss.backward()
    optimizer.step()
    print("epoch {}, loss {}, weights {}".format(
        epoch, loss.item(), model.weight_loss.param))

Output:

...
loss1 0.017472071573138237, loss2 0.008564327843487263
epoch 998, loss 0.020118102431297302, weights Parameter containing:
tensor([-0.0558,  1.0558], requires_grad=True)
loss1 0.017431674525141716, loss2 0.00854448415338993
epoch 999, loss 0.020070992410182953, weights Parameter containing:
tensor([-0.0559,  1.0559], requires_grad=True)
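
Note that the printed weights are the raw parameters, not the normalized weights that actually scale the losses. To inspect the effective weights you can print their softmax, e.g.:

with torch.no_grad():
    print(F.softmax(model.weight_loss.param, dim=0))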

I don't know whether calculating final_loss first and then adding weighted_loss to it is beneficial or not.
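
If you also want the less accurate model to contribute less to the final prediction itself (not only to the loss), one option would be to reuse the same softmax weights to pool the outputs instead of taking a plain mean. A minimal sketch, where weighted_pool is a hypothetical helper and not part of your code:

def weighted_pool(outputs, params):
    # outputs: stacked model outputs of shape [n_models, batch, classes]
    # params: raw weighting parameters of shape [n_models]
    w = F.softmax(params, dim=0).view(-1, 1, 1)  # normalize and broadcast over batch/classes
    return (outputs * w).sum(dim=0)

Inside Net2.forward you could then replace the plain mean with:

final_output = weighted_pool(torch.stack([output1, output2]), self.weight_loss.param)

With equal weights this reduces to the current mean pooling, and as the weights adapt the better model dominates the ensemble output.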