Indices from `max` are not differentiable in custom loss function

I’m training a network with a custom loss function, and it raises a RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn. It seems that the indices returned by max are not differentiable, but I have to use them in the loss function.
Does anyone have an idea how to fix this?
This is my code:

class MyLoss(nn.Module):
    def __init__(self):
        super(MyLoss, self).__init__()
        
    def forward(self, d, m_label, n_label):
        # d size : [N, C, H, W], from softmax output
        # m_label size : [N, 1, H, W]
        # n_label size : [N, 1, H, W]
        
        # get the predicted class
        pred = d.max(1)[1]
        pred = pred.unsqueeze(1)
        n_pred = 2.7*pred + m_label
        
        # calc image grad along x, y
        m_dx = m_label[:, :, 1:, :] - m_label[:, :, :-1, :]
        m_dy = m_label[:, :, :, 1:] - m_label[:, :, :, :-1]
        n_dx = n_pred[:, :, 1:, :] - n_pred[:, :, :-1, :]
        n_dy = n_pred[:, :, :, 1:] - n_pred[:, :, :, :-1]
        
        # loss1: minimize the difference of grad of m_label and n_pred
        loss1 = ((n_dx - m_dx).abs().sum()
                   + (n_dy - m_dy).abs().sum())
        
        # loss2: minimize the difference of n_pred and n_label
        loss2 = nn.functional.l1_loss(n_pred, n_label)
        return loss1 + loss2

Hi Lee!

First the “why” of your error:

The PyTorch framework uses various gradient-descent methods
to train your network – that is to minimize your loss function
with respect to your model’s parameters. For this to work, your
loss function must be a (usefully) differentiable function of the
model parameters.

As you have seen, the argmax() function (“the indices from
max”) is not differentiable, so you can’t use it in your loss function.
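A minimal reproduction of what’s happening: the indices that max() returns are integers with no grad_fn, so anything built from them is detached from the autograd graph:

```python
import torch

d = torch.randn(2, 3, 4, 4, requires_grad=True)
pred = d.max(1)[1]            # integer class indices (argmax)
print(pred.dtype)             # torch.int64 -- integer-valued
print(pred.requires_grad)     # False -- detached from the graph
loss = pred.float().mean()
print(loss.grad_fn)           # None, so loss.backward() raises the
                              # "does not require grad" RuntimeError
```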

As to how to fix your issue:

The short answer is that you have to modify your loss function
to be (usefully) differentiable.

I don’t understand the concept behind your loss functions, so I
don’t have suggestions for you.

What type of problem are you trying to solve?

What are the inputs to your network, and what do those inputs
mean?

What are the meanings of m_label and n_label?

What, conceptually, determines whether your model has made
a good prediction, and how is your loss function trying to measure
that?

Good luck!

K. Frank

Thanks for the reply.
I’m trying to solve a multi-class classification problem called phase unwrapping using a U-Net. The final layer is a 1x1 conv followed by a softmax.
There is a relationship between the images m_label and n_label: n_label = m_label + constant*k, in which m_label is the network input and k is the predicted class (an integer matrix) from the softmax output.
To make a good prediction, I minimize:

  1. the difference between the gradients of m_label and n_pred
  2. the difference between n_label and n_pred (calculated as n_pred = m_label + constant*k_pred)
  3. the cross-entropy loss between k_pred and k_label (calculated as round((n_label - m_label)/constant)), not included in the code above
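For what it’s worth, one common workaround (my own assumption, not something from the problem statement, and you’d have to check it makes sense for phase unwrapping) is to replace the hard argmax with the softmax-weighted expected class index, which is differentiable. A sketch of the loss rewritten that way, reusing the 2.7 constant and tensor shapes from the code above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyLoss(nn.Module):
    def __init__(self, constant=2.7):
        super().__init__()
        self.constant = constant

    def forward(self, d, m_label, n_label):
        # d: [N, C, H, W] softmax probabilities
        c = d.size(1)
        idx = torch.arange(c, dtype=d.dtype, device=d.device).view(1, c, 1, 1)
        # expected class index sum_c c * p_c -- a soft, differentiable
        # surrogate for the hard argmax
        k_soft = (d * idx).sum(dim=1, keepdim=True)   # [N, 1, H, W]
        n_pred = self.constant * k_soft + m_label

        # image gradients, as in the original code
        m_dx = m_label[:, :, 1:, :] - m_label[:, :, :-1, :]
        m_dy = m_label[:, :, :, 1:] - m_label[:, :, :, :-1]
        n_dx = n_pred[:, :, 1:, :] - n_pred[:, :, :-1, :]
        n_dy = n_pred[:, :, :, 1:] - n_pred[:, :, :, :-1]

        loss1 = ((n_dx - m_dx).abs().sum()
                   + (n_dy - m_dy).abs().sum())
        loss2 = F.l1_loss(n_pred, n_label)
        return loss1 + loss2
```

The usual alternative is to drop the argmax from the loss entirely and train k directly with cross entropy against k_label, as in point 3.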