How to backpropagate a pixel-level loss through different branches of a network?

I am trying to implement Multiple Hypothesis colorization in PyTorch (https://arxiv.org/pdf/1606.06314.pdf). The original code is at:
(https://bitbucket.org/harislorenzo/colorization-for-image-compression/src/3c51ce221af73cd0232fc3aaf42b3ac8453c1ec2/src/caffe/layers/euclidean_loss_layer.cu?at=master&fileviewer=file-view-default)

My implementation is:

import torch
from torch.autograd import Variable

def multiple_hypothesis(net, x, target, K):
    b, h, w, c = x.size()
    branches_loss = torch.FloatTensor(K, b, h, w, c)
    predictions = net(x)  # net returns 3 outputs, one from each of 3 branches of the network
    # Loss for each prediction at pixel level
    for k in range(K):
        branches_loss[k] = (predictions[k].data - target.data) ** 2
    # Find the branch with the minimum loss at each pixel
    values, index = torch.min(branches_loss, 0)
    grad = []
    for k in range(K):
        grad.append(Variable(torch.mul(branches_loss[k], (index == k).float())))

grad[0] has size [batch_size, height, width, channel].
How do I backpropagate the loss through each branch?

So for each pixel, you want to use the branch that has the minimum pixel-wise loss.

I think a single min operation makes the code much simpler:

import torch

def multiple_hypothesis(net, x, target):    # K is not needed as an argument,
    predictions = net(x)                    # because predictions already has K elements.
    # Per-pixel squared error of each branch; no .data here, so autograd tracks everything.
    branches_loss = [(p - target).pow(2) for p in predictions]
    loss_stack = torch.stack(branches_loss, 0)    # Shape: K x b x h x w x c
    loss_min = loss_stack.min(0)[0]               # min(0) returns (values, indices); keep the values.
    return loss_min                               # Per-pixel minimum loss over the K branches

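I think you can backpropagate the loss by averaging loss_min and calling backward() on the result. The min operation passes the gradient only to the element it selected, so each pixel updates only the branch with the lowest loss at that pixel.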
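
To make the backward step concrete, here is a minimal sketch of one training step, under some assumptions. ToyBranches is a hypothetical stand-in for the real colorization network (a shared layer followed by K small heads), min_over_branches_loss is a made-up helper name, and the tensor sizes are arbitrary. The final loop checks that, after backward(), each prediction receives a non-zero gradient only where its branch won the per-pixel min.

```python
import torch
import torch.nn as nn

class ToyBranches(nn.Module):
    """Hypothetical stand-in for the real network: a shared layer plus K heads."""
    def __init__(self, channels=3, K=3):
        super().__init__()
        self.shared = nn.Linear(channels, 8)
        self.heads = nn.ModuleList([nn.Linear(8, channels) for _ in range(K)])

    def forward(self, x):
        h = torch.relu(self.shared(x))
        return [head(h) for head in self.heads]  # list of K predictions

def min_over_branches_loss(net, x, target):
    predictions = net(x)
    # Stack per-branch squared errors into a K x b x h x w x c tensor.
    loss_stack = torch.stack([(p - target).pow(2) for p in predictions], 0)
    # values: smallest error per element; winner: index of the branch that produced it.
    loss_min, winner = loss_stack.min(0)
    return loss_min.mean(), winner, predictions

torch.manual_seed(0)
net = ToyBranches()
x = torch.randn(2, 4, 4, 3)        # b x h x w x c, the layout used in the question
target = torch.randn(2, 4, 4, 3)

optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
optimizer.zero_grad()
loss, winner, predictions = min_over_branches_loss(net, x, target)
for p in predictions:
    p.retain_grad()                # keep gradients of these non-leaf outputs for inspection
loss.backward()
optimizer.step()

# Where branch k did not win the min, its prediction gets exactly zero gradient.
for k, p in enumerate(predictions):
    lost = winner != k
    assert torch.all(p.grad[lost] == 0)
```

The retain_grad() calls are only there for the inspection loop at the end; a real training loop would not need them.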

Thanks Sanghyun. It works