My custom loss never goes backward

I am implementing a custom loss for my multi-class classification application.
For every y_true, there is a corresponding P_true which represents a list of values.
For every y_pred, after some reformatting, I use it as a key into an external dictionary to look up another list of values (the same length as P_true).
I’d like to calculate the MSE loss between these two lists of values (true_samples, pred_samples), and use the result as the output of this custom loss function.

import torch
from torch.autograd import Variable  # legacy wrapper; plain tensors work in recent PyTorch

# idx2class, name_samples_mapping, device, and mseloss (an nn.MSELoss instance)
# are defined elsewhere in my training script.

def custom_loss_fn(y_pred, y_true, P_true):
    y_pred_softmax = torch.log_softmax(y_pred, dim=1)
    # take the most likely class index for each sample in the batch
    _, y_pred_tags = torch.max(y_pred_softmax, dim=1)

    true_samples = Variable(P_true, requires_grad=True).to(device)

    # map each predicted class index to its name, then look up its list of samples
    pred_cluster_names = [idx2class[tag.squeeze().tolist()]
                          for tag in y_pred_tags]
    pred_samples = Variable(
        torch.FloatTensor(
            [name_samples_mapping[pred_name]
                for pred_name in pred_cluster_names]
        ), requires_grad=True).to(device)

    output = mseloss(true_samples, pred_samples)

    return output

But I found that the training loss never changed from epoch to epoch, which suggests the computation graph is broken somewhere and autograd cannot propagate gradients. Could anyone help? Thanks!

Edit: true_samples tensors have requires_grad=True, pred_samples tensors have grad_fn=<CopyBackwards>, and output for training loss has grad_fn=<MeanBackward0>.

Hi,

The problem is that the argmax operation you do on y_pred at the beginning is not differentiable: it returns integer indices.

So no gradient will be able to flow back to y_pred and to your network above :confused:
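
You can see it directly if you inspect what torch.max returns (toy values here, just for illustration):

import torch

# toy logits standing in for the network output (made-up values)
y_pred = torch.randn(4, 3, requires_grad=True)
y_pred_softmax = torch.log_softmax(y_pred, dim=1)

values, indices = torch.max(y_pred_softmax, dim=1)
print(values.grad_fn)   # a Max backward node -- this path is differentiable
print(indices.grad_fn)  # None -- integer indices, autograd stops here
print(indices.dtype)    # torch.int64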

Thank you for your insight! I am new to PyTorch and didn’t realize argmax was the root cause of the problem; I had been suspecting that my dictionary lookups broke the connectivity since they were not wrapped in Variables.

I read your answer to a similar problem (link here); would you mind suggesting some ways to work around this argmax issue and get my backprop working? Thanks!

As mentioned in the other thread, replacing this kind of argmax is almost an area of research in itself :slight_smile:

I guess the simplest thing you can try here is a weighted version based on the softmax values, so that the predicted samples are a weighted combination of the different dictionary entries.
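
Something like this, very roughly (just a sketch; idx2class and name_samples_mapping are your objects, and I’m assuming every class’s sample list has the same length so they can be stacked into one matrix):

import torch
import torch.nn.functional as F

# Build a [num_classes, sample_len] matrix once from your dictionary, e.g.:
# class_samples = torch.stack([
#     torch.as_tensor(name_samples_mapping[idx2class[i]], dtype=torch.float32)
#     for i in range(len(idx2class))
# ]).to(device)

def soft_custom_loss_fn(y_pred, P_true, class_samples):
    probs = F.softmax(y_pred, dim=1)         # [batch, num_classes], differentiable
    pred_samples = probs @ class_samples     # weighted mix of the dictionary entries
    return F.mse_loss(pred_samples, P_true)  # gradients flow back through probs to y_pred
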
But this might not work…

I think my custom loss function is essentially more of a ranking function than a classification one. What I mean is: if a y_pred looks up pred_samples that yield a lower MSE against true_samples, then that y_pred should stand out among all possible y classes in the candidate list.

Also, I read some other resources; do you think these materials can lead me in the right direction?

Thanks a lot for your time and reply.