So, for each pixel, you are using the branch that has the minimum pixel-wise loss.
I think a plain min operation would make the code much simpler:
import torch

def multiple_hypothesis(net, x, target):  # Actually, you don't have to pass K here.
    predictions = net(x)  # A list of K predictions, one per branch.
    branches_loss = [(p - target).pow(2) for p in predictions]  # predictions already has K elements.
    loss_stack = torch.stack(branches_loss, 0)  # Shape: K x b x h x w x c
    loss_min, min_idx = loss_stack.min(0)  # min over dim 0 returns (values, indices).
    return loss_min, min_idx
I think you can backpropagate by averaging the minimum loss values (the first element returned by min) and calling backward() on the result.
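To make the whole idea concrete, here is a minimal end-to-end sketch. ToyMultiHead is a made-up stand-in for your network (I am assuming your net returns a list of K tensors with the same shape as the target), and min_over_hypotheses is just a name I picked for illustration. The tensor sizes are arbitrary.

```python
import torch
import torch.nn as nn

class ToyMultiHead(nn.Module):
    """Hypothetical stand-in network: returns a list of K predictions."""
    def __init__(self, k=3, channels=2):
        super().__init__()
        self.heads = nn.ModuleList([nn.Conv2d(channels, channels, 1) for _ in range(k)])

    def forward(self, x):
        return [head(x) for head in self.heads]

def min_over_hypotheses(net, x, target):
    predictions = net(x)
    # Element-wise squared error per branch, stacked along a new leading dim of size K.
    loss_stack = torch.stack([(p - target).pow(2) for p in predictions], 0)
    # For every element, keep the smallest loss across the K branches.
    loss_min, min_idx = loss_stack.min(0)
    # Average to a scalar so backward() can be called on it.
    return loss_min.mean(), min_idx

torch.manual_seed(0)
net = ToyMultiHead()
x = torch.randn(4, 2, 8, 8)
target = torch.randn(4, 2, 8, 8)

loss, min_idx = min_over_hypotheses(net, x, target)
loss.backward()
```

At each element, only the branch that attained the minimum receives a nonzero gradient. min_idx tells you which branch won where, in case you want to inspect it.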