How to backpropagate pixel level loss in different branches of network?

I am trying to implement Multiple Hypothesis colorization in PyTorch (… and the original code is at: …).

My implementation is:

import torch
from torch.autograd import Variable

def multiple_hypothesis(net, x, target, K):
    b, h, w, c = x.size()
    branches_loss = torch.FloatTensor(K, b, h, w, c)
    predictions = net(x)  # net returns 3 outputs from 3 different branches of the network
    ## Loss for each prediction at pixel level
    for k in range(K):
        branches_loss[k] = (predictions[k].data - target.data) ** 2
    ## Find the minimum-loss prediction at each pixel
    values, index = torch.min(branches_loss, 0)
    grad = []
    for k in range(K):
        grad.append(Variable(torch.mul(branches_loss[k], (index == k).float())))

grad[0] will have size [batch_size, height, width, channel].
How do I backpropagate the loss into each branch?
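One likely reason nothing reaches the branches in the code above: predictions[k].data returns a tensor detached from the autograd graph, and wrapping the result in a new Variable creates a fresh leaf with no history, so grad has no path back to the network. Below is a minimal sketch of the same masking idea kept inside the graph. The function `masked_min_loss` and the toy tensors are hypothetical; it assumes `predictions` is a list of K tensors shaped like `target`.

```python
import torch

def masked_min_loss(predictions, target):
    """Per-pixel winner-take-all loss built with an explicit mask.

    Hypothetical helper: predictions is a list of K tensors shaped like target,
    still attached to the autograd graph (no .data, no re-wrapping in Variable).
    """
    # K x b x h x w x c stack of squared errors; keeps gradient history.
    branches_loss = torch.stack([(p - target) ** 2 for p in predictions], 0)
    # Index of the best branch at every element; integer indices carry no gradient.
    index = branches_loss.argmin(0)
    # One-hot mask over K: mask[k] is 1 where branch k has the smallest loss.
    mask = torch.stack([(index == k).float() for k in range(len(predictions))], 0)
    # Losing entries are multiplied by 0, so only the winning branch gets gradient.
    return (branches_loss * mask).sum(0).mean()

# Usage: three hypothetical branch outputs for a 1 x 2 x 2 x 1 (b, h, w, c) image.
torch.manual_seed(0)
target = torch.zeros(1, 2, 2, 1)
preds = [torch.randn(1, 2, 2, 1, requires_grad=True) for _ in range(3)]
loss = masked_min_loss(preds, target)
loss.backward()

# Count, per element, how many branches received a nonzero gradient.
nonzero = sum((p.grad != 0).float() for p in preds)
print(nonzero.flatten())  # every entry should be 1: exactly one branch per element
```

Because index is computed from the same losses, this mask selects the same entries that a plain min over the K dimension would, so the two formulations give the same loss and the same gradients (barring exact ties).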

So, for each pixel, you want to train only the branch that has the minimum pixel-wise loss.
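Written as a formula (a sketch in notation introduced here: K branch outputs $\hat{y}^{(k)}$, target $y$, $N = b \cdot h \cdot w \cdot c$ elements, ties ignored), averaging the per-pixel minimum gives a loss whose gradient reaches only the winning branch at each pixel:

```latex
\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \min_{k \in \{1,\dots,K\}} \bigl(\hat{y}^{(k)}_i - y_i\bigr)^2,
\qquad
\frac{\partial \mathcal{L}}{\partial \hat{y}^{(k)}_i} =
\begin{cases}
\frac{2}{N}\bigl(\hat{y}^{(k)}_i - y_i\bigr) & \text{if } k = k^*_i,\\
0 & \text{otherwise,}
\end{cases}
\qquad
k^*_i = \arg\min_{k} \bigl(\hat{y}^{(k)}_i - y_i\bigr)^2
```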

I think a single min operation can make the code much simpler:

import torch

def multiple_hypothesis(net, x, target):                # Actually, you don't have to pass K here,
    predictions = net(x)                                # because predictions already has K elements.
    branches_loss = [(p - target).pow(2) for p in predictions]
    loss_stack = torch.stack(branches_loss, 0)          # K x b x h x w x c tensor (a Variable in older PyTorch)
    loss_min = loss_stack.min(0)[0]                     # min() returns (values, indices); [0] keeps the values.
    return loss_min

I think you can backpropagate the loss by averaging loss_min and calling backward() on the result.
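A minimal end-to-end sketch of that suggestion, assuming a hypothetical toy model `ThreeBranchNet` (a shared conv trunk feeding three 1x1-conv heads) and random data. It uses PyTorch's NCHW layout because of nn.Conv2d, rather than the b x h x w x c layout in the thread; the min-based function is the one above with its return value used.

```python
import torch
import torch.nn as nn

class ThreeBranchNet(nn.Module):
    """Hypothetical toy model: a shared trunk feeding K separate 1x1-conv heads."""
    def __init__(self, k=3, in_ch=1, out_ch=2):
        super().__init__()
        self.trunk = nn.Conv2d(in_ch, 8, kernel_size=3, padding=1)
        self.heads = nn.ModuleList([nn.Conv2d(8, out_ch, kernel_size=1) for _ in range(k)])

    def forward(self, x):
        features = torch.relu(self.trunk(x))
        # One prediction per branch, each shaped like the target.
        return [head(features) for head in self.heads]

def multiple_hypothesis(net, x, target):
    predictions = net(x)
    branches_loss = [(p - target).pow(2) for p in predictions]
    loss_stack = torch.stack(branches_loss, 0)  # K x b x c x h x w
    return loss_stack.min(0)[0]                 # per-element minimum over branches

torch.manual_seed(0)
net = ThreeBranchNet()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
x = torch.randn(4, 1, 8, 8)       # hypothetical grayscale batch
target = torch.randn(4, 2, 8, 8)  # hypothetical 2-channel color target

optimizer.zero_grad()
loss = multiple_hypothesis(net, x, target).mean()  # average the per-pixel minimum
loss.backward()   # gradients flow through min() only into the head that won each element
optimizer.step()
```

No mask is needed: the backward pass of min() already routes each element's gradient to the selected branch, and heads that lose at a given element receive zero there while the shared trunk accumulates gradient from all winners.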


Thanks, Sanghyun. It works!
