As implied in the title, I tried to define a custom loss function for a simple LSTM model by combining NLL loss and MSE loss. Since the output of my model is a probability distribution over several classes, I take the argmax of it so that I can compute the MSE loss.
But I run into this error:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
It seems that the `argmax()` function changes the `requires_grad` attribute of a tensor to `False`. Is there a way to ensure this doesn't happen? Or is this because differentiating `max()` is mathematically not well defined?
```python
pred = model(x)
# Flatten (batch, seq, classes) -> (batch * seq, classes) for batch loss computation
flat_pred = torch.reshape(pred, (pred.size(0) * pred.size(1), pred.size(2)))
loss = mse(torch.argmax(flat_pred, dim=1).double(), y.flatten().double())  # + nll(flat_pred, y.flatten())
```
If I print `requires_grad`, I see that it is `True` for the `flat_pred` tensor, but `torch.argmax(flat_pred, dim=1).double().requires_grad` is `False`.
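For reference, a minimal standalone repro (with random tensors standing in for the model output, not the original code) shows the same behaviour:

```python
import torch

flat_pred = torch.randn(6, 4, requires_grad=True)  # stand-in for the flattened output

print(flat_pred.requires_grad)         # True
idx = torch.argmax(flat_pred, dim=1).double()
print(idx.requires_grad, idx.grad_fn)  # False None

loss = torch.nn.functional.mse_loss(idx, torch.zeros(6, dtype=torch.float64))
# loss.backward()  # -> RuntimeError: element 0 of tensors does not require grad ...
```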
I have to do it because `argmax` returns a tensor of type `Long` and MSE is not defined for that kind of tensor. I will try rewriting the train loop; maybe there's another insidious bug that I didn't spot.
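One standard workaround, not mentioned in this thread but worth noting, is a "soft argmax": take the softmax of the scores and compute the expected class index as a probability-weighted sum, which stays differentiable. A minimal sketch, assuming `flat_pred` holds raw scores of shape (N, C) and `y` holds integer labels:

```python
import torch

def soft_argmax(logits: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Differentiable surrogate for argmax: the expected index under the softmax."""
    probs = torch.softmax(logits, dim=dim)
    idx = torch.arange(logits.size(dim), dtype=probs.dtype, device=logits.device)
    # Probability-weighted average of the class indices; unlike torch.argmax,
    # gradients flow back through `probs`.
    return (probs * idx).sum(dim=dim)

flat_pred = torch.randn(6, 4, requires_grad=True)  # stand-in for the flattened output
y = torch.randint(0, 4, (6,))

mse = torch.nn.MSELoss()
loss = mse(soft_argmax(flat_pred), y.float())
print(loss.requires_grad)  # True
loss.backward()            # no RuntimeError
```

The result is a float approximation of the index rather than a hard index, so the MSE term now back-propagates into the network.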
Sorry, I mixed up `max`. @JohnCwok is right, the indices returned by `argmax` do not carry gradients. In this case, if your output is one-hot, you don't have to use `argmax`; just feed it into the MSE loss directly.
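If I read that suggestion correctly, it amounts to one-hot encoding the integer targets and comparing them to the predicted distribution directly, so no `argmax` is needed. A minimal sketch under that assumption (`flat_pred` as class probabilities and `y` as integer labels are hypothetical stand-ins):

```python
import torch
import torch.nn.functional as F

n_classes = 4
logits = torch.randn(6, n_classes, requires_grad=True)  # stand-in for the LSTM output
flat_pred = torch.softmax(logits, dim=1)                # probabilities over classes
y = torch.randint(0, n_classes, (6,))

# One-hot encode the integer labels so MSE compares distribution to distribution.
y_onehot = F.one_hot(y, num_classes=n_classes).to(flat_pred.dtype)

loss = F.mse_loss(flat_pred, y_onehot)  # + F.nll_loss(torch.log(flat_pred), y)
loss.backward()  # gradients flow; no argmax involved
```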