Hey, I am trying to implement the following loss function.
def calculate_asn_loss_t_test_f(original_tensor, reconstructed_tensor, beta = 0.8, lambda_p = 1, lambda_n = 0.01):
residual_tensor = torch.abs(reconstructed_tensor - original_tensor).float().cuda()
residual_tensor = residual_tensor.sum(dim = [1, 2, 3]).float().cuda()
residual_l = original_tensor.size(0)
pos_residual = residual_tensor[: residual_l//2].float().cuda()
neg_residual = residual_tensor[residual_l//2 :].float().cuda()
pos_mean = pos_residual.mean().cuda().float()
neg_mean = neg_residual.mean().cuda().float()
pos_std = pos_residual.std().cuda()
neg_std = neg_residual.std().cuda().float()
a = beta - pos_mean.float()
b = neg_mean + lambda_p * (pos_std ** 2).float() + lambda_n * (neg_std ** 2).float()
if a > b.mean():
print("RETURNING loss a: ", a)
return a + epsilon
else:
print("RETURNING loss b: ", b)
return b.mean() + epsilon
But it is giving me the following error:
RuntimeError: Function ‘StdBackward0’ returned nan values in its 0th output.
Any help?