Hindsight / Multi-Output Loss


I have built a model that produces multiple solutions to the same classification problem, i.e., it spreads its bets. The idea is to pick the best of the proposed solutions. The loss that papers [1,2,3] use for such a network is called the hindsight loss, and the idea is simply to take the minimum of the losses over all the outputs. My current implementation looks as follows:
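For reference, here is how I read the formulation in [1] (where it is called the oracle loss) and [3], written in my own notation. N is the number of examples, M the number of outputs, ŷ_i^(m) the m-th output for example i, and ℓ the per-example loss (BCE in my case). I average over examples here; [1] writes a sum, which only changes the scale:

$$
\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \min_{m \in \{1, \dots, M\}} \ell\big(y_i, \hat{y}_i^{(m)}\big)
$$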

import torch
import torch.nn as nn

def hindsight_loss(output, labels):
    # output: [batch, num_outputs] probabilities (e.g. after a sigmoid)
    if labels.dim() == 1:
        # reshape labels from [0, 0, 1, 0] to [[0], [0], [1], [0]]
        labels = labels.unsqueeze(1)

    labels = labels.to(torch.float32)

    ce_func = nn.BCELoss()  # default reduction='mean': one scalar per output
    loss = ce_func(output[:, 0, None], labels)
    for res in range(1, output.shape[1]):
        new_loss = ce_func(output[:, res, None], labels)
        # keep the smallest batch-averaged loss seen so far
        loss = torch.min(loss, new_loss)

    return loss

However, with this implementation only one of my outputs gets improved during backpropagation. I suppose this is because the gradient only flows through the output that attains the minimum, so the weights that produce the other outputs are never updated. I do not think this is what the authors intended, but I could not find an implementation of the hindsight loss in PyTorch (or any other framework, for that matter).
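To see where the gradients actually go, here is a minimal sketch. It assumes `output` holds sigmoid probabilities of shape [batch, num_outputs] and 1-D float labels; the function names `batch_min_loss`, `per_sample_min_loss` and `heads_with_gradient`, as well as the toy logits, are hypothetical and only for illustration. The first function reproduces the batch-level minimum from my code above; the second takes the minimum separately for each example using `reduction="none"`, which is my reading of the per-example formulation in [1] and [3].

```python
import torch
import torch.nn.functional as F


def batch_min_loss(output, labels):
    # Like the code above: one batch-averaged BCE per output, then the
    # minimum across outputs, so a single output is selected per batch.
    losses = torch.stack([
        F.binary_cross_entropy(output[:, m], labels)
        for m in range(output.shape[1])
    ])
    return losses.min()


def per_sample_min_loss(output, labels):
    # Per-example BCE for every output: shape [batch, num_outputs].
    target = labels.unsqueeze(1).repeat(1, output.shape[1])
    losses = F.binary_cross_entropy(output, target, reduction="none")
    # Minimum over outputs for each example, then the mean over the batch.
    return losses.min(dim=1).values.mean()


def heads_with_gradient(loss_fn, logits, labels):
    # Count how many output columns receive a nonzero gradient.
    logits = logits.detach().clone().requires_grad_(True)
    loss = loss_fn(torch.sigmoid(logits), labels)
    loss.backward()
    return int((logits.grad.abs().sum(dim=0) > 0).sum())


# Toy data: 4 examples, 3 outputs; different outputs are best on different examples.
labels = torch.tensor([1.0, 0.0, 1.0, 0.0])
logits = torch.tensor([
    [3.0, -1.0, -1.0],
    [1.0, -3.0, 1.0],
    [-1.0, -1.0, 3.0],
    [1.0, 1.0, -3.0],
])

print(heads_with_gradient(batch_min_loss, logits, labels))       # 1
print(heads_with_gradient(per_sample_min_loss, logits, labels))  # 3
```

With the batch-level minimum, only the single output with the lowest average loss receives a gradient; with the per-example minimum, every output that is best on at least one example gets updated.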

Does anybody know whether something is wrong with my implementation and/or my understanding of the issue? As it stands, there is no point in learning to produce multiple outputs, since only one of them is ever improved.

Thank you so much.


[1] Multiple Choice Learning: Learning to Produce Multiple Structured Outputs
[2] Interactive Image Segmentation with Latent Diversity
[3] Combinatorial Optimization with Graph Convolutional Networks and Guided Tree Search