Custom loss function based on NLL + MSE

Hi!

As implied in the title, I tried to define a custom loss function for a simple LSTM model by combining NLL loss and MSE loss. Since the output of my model is a probability distribution over several classes, I take the argmax of it so that I can compute the MSE loss.
But I run into this error:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

It seems that the argmax() function changes the requires_grad attribute of a tensor to False. Is there a way to ensure this doesn't happen? Or is this because differentiating argmax() is mathematically not well-defined?

        pred = model(x)[0]
        # flatten for batch loss computation
        flat_pred = torch.reshape(pred, (pred.size(0) * pred.size(1), pred.size(2)))
        loss = mse(torch.argmax(flat_pred, dim=1).double(), y.flatten().double())  # + nll(flat_pred, y.flatten())

If I print requires_grad, I see that it is True for the flat_pred tensor, but torch.argmax(flat_pred, dim=1).double().requires_grad is False.
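For reference, here is a minimal standalone repro of that detachment (shapes are illustrative, not from the actual model): argmax returns integer indices, which have no grad_fn, so the result is cut off from the autograd graph even after casting back to double.

```python
import torch

# Stand-in for flat_pred: a tensor that participates in autograd.
pred = torch.randn(4, 3, requires_grad=True)
idx = torch.argmax(pred, dim=1)

print(pred.requires_grad)          # True
print(idx.double().requires_grad)  # False -- casting does not restore the graph

# Calling backward() through idx would raise the RuntimeError above:
# loss = torch.nn.functional.mse_loss(idx.double(), torch.zeros(4))
# loss.backward()  # RuntimeError: element 0 of tensors does not require grad ...
```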

I have to do it because argmax returns a tensor of type Long, and MSE is not defined for this kind of tensor. I will try rewriting the train loop; maybe there's another insidious bug that I didn't spot.

Sorry, I mixed up argmax with max. @JohnCwok is right: the indices returned by argmax do not carry gradients. In this case, if your target is one-hot, you don't have to use argmax at all; just feed the predicted distribution into MSELoss.
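A sketch of what that could look like (variable names and shapes are assumptions, not the original model's): instead of argmax-ing, compare the predicted probability distribution against a one-hot encoding of the targets, so both the MSE and the NLL terms stay differentiable.

```python
import torch
import torch.nn.functional as F

num_classes = 3
# Stand-in for the flattened model output (raw scores per class).
flat_pred = torch.randn(8, num_classes, requires_grad=True)
y = torch.randint(0, num_classes, (8,))  # integer class targets

log_probs = F.log_softmax(flat_pred, dim=1)   # for the NLL term
probs = log_probs.exp()                       # probability distribution
one_hot = F.one_hot(y, num_classes).double()  # differentiable MSE target

loss = F.mse_loss(probs.double(), one_hot) + F.nll_loss(log_probs, y)
loss.backward()  # works: both terms keep gradients w.r.t. flat_pred
```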