Custom loss function value does not decrease

I have the following custom loss function, which I want to minimize.

import torch

def compute_clamped_dist(target_coordinates_tensor, observed_coordinates_tensor):
    # Half the row-wise Euclidean distance between the coordinate pairs,
    # clamped so it never exceeds 1.0.
    euclidean_dist = 0.5 * torch.sqrt(
        torch.sum((target_coordinates_tensor - observed_coordinates_tensor) ** 2, dim=1))
    return torch.clamp(euclidean_dist, max=1.0)

def f1_loss(target_tensor, observed_tensor):
    # The last 3 columns of the target hold the (constant) classification weights.
    classification_tensor = target_tensor[:, 4:]
    observed_tensor_12 = observed_tensor[:, :2]   # first predicted coordinate pair
    observed_tensor_34 = observed_tensor[:, 2:]   # second predicted coordinate pair
    target_tensor_12 = target_tensor[:, :2]       # first target coordinate pair
    distance_o12_t12 = compute_clamped_dist(target_tensor_12, observed_tensor_12)
    distance_o34_t12 = compute_clamped_dist(target_tensor_12, observed_tensor_34)
    third_component = 0.5 * (1 + torch.min(distance_o12_t12, distance_o34_t12))
    # Weighted sum of the per-sample distance terms, averaged over the batch.
    output = (classification_tensor[:, 0]
              + torch.mul(classification_tensor[:, 1], distance_o12_t12)
              + torch.mul(classification_tensor[:, 2], third_component))
    return output.mean()
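
For context, a minimal self-contained sketch of calling the loss on its own (random data for illustration; shapes match the description below):

N = 8
target = torch.rand(N, 7)                        # first 4 columns: coordinates; last 3: constants
observed = torch.rand(N, 4, requires_grad=True)  # stand-in for the network output
loss = f1_loss(target, observed)
loss.backward()
print(loss.item(), observed.grad)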

The target tensor has size (N × 7) and the observation tensor has size (N × 4). I want the observation tensor to match the first 4 columns of the target tensor as closely as possible; the last 3 columns of the target tensor are constant. I am using my custom loss function as follows:

for batch_idx, sample_batched in enumerate(one_spot_train_dataloader):
    self.one_spot_optimizer.zero_grad()
    data = sample_batched['image']
    target = sample_batched['label_meta'].float()[:, 1:]
    if self.use_gpu:
        data = data.cuda(self.cuda_name)
        target = target.cuda(self.cuda_name)
        constant_classification_tensor = constant_classification_tensor.cuda(self.cuda_name)
    output = self.one_spot_network(data)
    # Slice out this batch's rows of the constant classification columns
    # (constant_classification_tensor is built outside the loop).
    constant_batch_classification_tensor = constant_classification_tensor[
        batch_idx * self.batch_size_train:(batch_idx + 1) * self.batch_size_train, :]
    final_target = torch.cat((target, constant_batch_classification_tensor), dim=1)
    loss = f1_loss(final_target, output)
    loss.backward()
    self.one_spot_optimizer.step()

The optimizer does not reduce the value of my loss function, but when I try mse_loss instead, the value is minimized. Can anybody help me understand why this is happening? My loss function is written entirely with torch operations, yet it is not getting minimized.

You usually don’t want to use a clamp operation in your loss function because its derivative is zero in the saturated region, i.e. no learning happens for any sample that is clamped.
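
A quick standalone example shows the saturated region killing the gradient:

import torch

x = torch.tensor([0.5, 2.0, 3.0], requires_grad=True)
y = torch.clamp(x, max=1.0).sum()
y.backward()
print(x.grad)  # tensor([1., 0., 0.]): zero gradient for the two clamped entries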


@colesbury is that the reason that the loss doesn’t get minimized at all?

It might be. You should inspect the gradients. Use retain_grad() to store the grad attribute on intermediate tensors and look at the values.

https://pytorch.org/docs/stable/autograd.html?highlight=grad#torch.Tensor.retain_grad
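
For example, a minimal sketch of the idea applied to a clamped distance (random data, not your exact code):

import torch

x = torch.rand(4, 2, requires_grad=True)
dist = torch.sqrt(torch.sum(x ** 2, dim=1))
dist.retain_grad()                   # keep .grad on this intermediate tensor
torch.clamp(dist, max=1.0).mean().backward()
print(dist.grad)                     # zeros wherever the clamp saturated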


@colesbury Thanks for your help. I tried your suggestion, and when I print output.grad it consists only of zeros and NaNs. Per your comment, I removed the clamp operation from my loss function and looked at the grad again; this time I see some real numbers along with some NaNs. More importantly, the optimizer is now able to reduce my loss function. So I think your intuition is correct and the clamp was blocking the gradients.
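
I suspect the remaining NaNs come from torch.sqrt, whose gradient is infinite at exactly zero, i.e. whenever a predicted pair coincides with its target. A sketch of the workaround I am considering (the eps value is an arbitrary choice):

def compute_dist(target_coordinates_tensor, observed_coordinates_tensor, eps=1e-8):
    squared = torch.sum((target_coordinates_tensor - observed_coordinates_tensor) ** 2, dim=1)
    # eps keeps sqrt away from 0, where its gradient blows up and produces NaNs
    return 0.5 * torch.sqrt(squared + eps)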

Thanks again for your help.

@mohsenkiskani
Can you give me more details? Where did you print output.grad: in the loss function or in the training loop?