I want to convert an image from LAB space to RGB space, and do backward to calculate the gradient of original LAB. But it seems that backward stuck! (forward is okay)! Please give some hints.
def tensor_lab2rgb(input):
"""
n * 3* h *w
"""
input_trans = input.transpose(1, 2).transpose(2, 3) # n * h * w * 3
L, a, b = input_trans[:, :, :, 0:1], input_trans[:, :, :, 1:2], input_trans[:, :, :, 2:]
y = (L + 16.) / 116.
x = (a / 500.) + y
z = y - (b / 200.)
neg_mask = z.data > 0.2068966 if isinstance(z, Variable) else z > 0.2068966
z[neg_mask] = 0
xyz = torch.cat((x, y, z), dim=3)
mask = xyz.data > 0.2068966 if isinstance(xyz, Variable) else xyz > 0.2068966
mask_xyz = xyz.clone()
mask_xyz[mask] = torch.pow(xyz[mask], 3.)
mask_xyz[~mask] = (xyz[~mask] - 16.0 / 116.) / 7.787
scale_mask_xyz = torch.cat((mask_xyz[:, :, :, [0]] * 0.95047,
mask_xyz[:, :, :, [1]],
mask_xyz[:, :, :, [2]] * 1.08883), dim=3)
if isinstance(scale_mask_xyz, Variable):
rgb_trans = torch.mm(scale_mask_xyz.view(-1, 3), Variable(torch.from_numpy(rgb_from_xyz).type_as(mask_xyz.data))).view(input.size(0), input.size(2), input.size(3), 3)
else:
rgb_trans = torch.mm(scale_mask_xyz.view(-1, 3), torch.from_numpy(rgb_from_xyz).type_as(mask_xyz)).view(input.size(0), input.size(2), input.size(3), 3)
rgb = rgb_trans.transpose(2, 3).transpose(1, 2)
mask = rgb.data > 0.0031308 if isinstance(rgb, Variable) else rgb > 0.0031308
mask_rgb = rgb.clone()
mask_rgb[mask] = 1.055 * torch.pow(rgb[mask], 1 / 2.4) - 0.055
mask_rgb[~mask] = rgb[~mask] * 12.92
mask_rgb[mask_rgb.data < 0] = 0
mask_rgb[mask_rgb.data > 1] = 1
return mask_rgb
The main call function is below:
img_lab_tensor_v = Variable(img_lab_tensor, requires_grad=True)
converted_rgb_img_v = tensor_lab2rgb(img_lab_tensor_v)
output_grad_v = torch.randn(converted_rgb_img_v.size())
converted_rgb_img_v.backward(output_grad_v)
lab_grad = img_lab_tensor_v.grad.data