I am trying to implement Multiple Hypothesis Colorization (https://arxiv.org/pdf/1606.06314.pdf) in PyTorch. The original (Caffe) implementation of the loss is at:
(https://bitbucket.org/harislorenzo/colorization-for-image-compression/src/3c51ce221af73cd0232fc3aaf42b3ac8453c1ec2/src/caffe/layers/euclidean_loss_layer.cu?at=master&fileviewer=file-view-default)
My implementation is:
```python
import torch
from torch.autograd import Variable

def multiple_hypothesis(net, x, target, K):
    b, h, w, c = x.size()
    branches_loss = torch.FloatTensor(K, b, h, w, c)
    predictions = net(x)  # net returns 3 outputs (K = 3), one from each of its 3 branches

    # Squared error of each branch's prediction, per element
    # (.data detaches these values from the autograd graph)
    for k in range(K):
        branches_loss[k] = (predictions[k].data - target.data) ** 2

    # For each element, find the branch with the smallest error
    values, index = torch.min(branches_loss, 0)

    # Keep each branch's loss only where that branch is the best one
    grad = []
    for k in range(K):
        grad.append(Variable(torch.mul(branches_loss[k], (index == k).float())))
```
`grad[0]` (like every other entry of `grad`) has size `[batch_size, height, width, channel]`.
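Written out, this is how I read the loss the code builds: for each element, only the best of the K branch errors counts. The averaging over the N = b·h·w·c elements is my assumption, not a formula copied from the paper. Differentiating it shows what each branch should receive during backpropagation:

```latex
L = \frac{1}{N}\sum_{e} \min_{k} \left(y_{k,e} - t_e\right)^2,
\qquad
\frac{\partial L}{\partial y_{k,e}} =
\frac{2}{N}\left(y_{k,e} - t_e\right)\,
\mathbb{1}\!\left[k = \arg\min_{j} \left(y_{j,e} - t_e\right)^2\right]
```

So the tensor handed back to branch k would be the masked residual 2(y_k − t)/N, not the masked squared error that `grad[k]` currently holds.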
How do I backpropagate this loss so that each branch receives its own gradient, coming only from the elements where it is the best prediction?
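A minimal sketch of two ways this could work. It assumes `net(x)` returns a list of K tensors, each shaped like `target` (channels-last `[b, h, w, c]`), and that the loss is averaged over elements. `MultiHeadNet`, `multiple_hypothesis_loss` and `backward_with_explicit_grads` are hypothetical names for illustration. The first function drops `.data` and lets autograd differentiate through `torch.min`, whose backward routes each element's gradient only to the branch recorded in `index`. The second keeps the mask idea from the question but passes per-branch gradients to `torch.autograd.backward`, using ∂L/∂y_k = 2(y_k − t)/N on the elements branch k wins.

```python
import torch
import torch.nn as nn


class MultiHeadNet(nn.Module):
    """Hypothetical toy stand-in for the colorization net: a shared trunk and K branches.

    nn.Linear acts on the last dim, so it works on channels-last [b, h, w, c] tensors.
    """

    def __init__(self, c_in=1, c_out=2, k=3):
        super().__init__()
        self.trunk = nn.Linear(c_in, 8)
        self.heads = nn.ModuleList([nn.Linear(8, c_out) for _ in range(k)])

    def forward(self, x):
        feat = torch.relu(self.trunk(x))
        return [head(feat) for head in self.heads]


def multiple_hypothesis_loss(predictions, target):
    """Mean over elements of min_k (prediction_k - target)^2, kept differentiable."""
    # No .data here: stacking the live outputs keeps the graph back to every branch.
    sq = torch.stack([(p - target) ** 2 for p in predictions])  # [K, b, h, w, c]
    best, index = torch.min(sq, dim=0)  # per-element minimum over the K branches
    # Backward through min sends each element's gradient only to the branch in `index`.
    return best.mean(), index


def backward_with_explicit_grads(predictions, target):
    """Same update, but feeding hand-computed dL/dy_k to each branch output."""
    with torch.no_grad():
        sq = torch.stack([(p - target) ** 2 for p in predictions])
        index = torch.min(sq, dim=0).indices
        n = target.numel()  # the loss above is a mean, hence the 1/n factor
        # dL/dy_k = 2 (y_k - t) / n where branch k wins, 0 elsewhere.
        grads = [2 * (p - target) * (index == k).float() / n
                 for k, p in enumerate(predictions)]
    torch.autograd.backward(predictions, grad_tensors=grads)


# Usage with random data (shapes are illustrative assumptions).
torch.manual_seed(0)
net = MultiHeadNet()
x = torch.randn(2, 4, 4, 1)       # grayscale input, channels-last
target = torch.randn(2, 4, 4, 2)  # e.g. two chroma channels
loss, index = multiple_hypothesis_loss(net(x), target)
loss.backward()  # one call reaches all branches through the elements each one wins
```

In both versions a single backward pass updates every branch only from the elements where it was the best hypothesis, while the shared trunk accumulates the sum of those contributions. The autograd version is shorter; the explicit version mirrors the mask-based structure of the code in the question, and should produce the same parameter gradients.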