MSELoss and torch.max

I have a model with 10 outputs. I want to select the index with the maximum value and compute MSELoss against my label. However, if I do _, max_index = torch.max(raw_result, 1), then max_index.requires_grad is False, which prevents loss.backward() from propagating gradients. Does anyone have experience with how to do this?

Did you try torch.argmax?


import torch

x = torch.randn(10, 20)
x.requires_grad = True
output = torch.argmax(x, 1)
output.requires_grad  # False -> argmax also returns non-differentiable indices

On the other hand, propagating the gradient back through the selected index seems hard, as discussed in this post.
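For reference, note that only the integer indices are non-differentiable; the values returned by torch.max do carry gradients. A quick check (shapes are arbitrary):

```python
import torch

x = torch.randn(10, 20, requires_grad=True)

# torch.max along a dim returns (values, indices): the values are
# differentiable, the integer indices are not.
max_vals, max_idx = torch.max(x, dim=1)
print(max_vals.requires_grad)  # True
print(max_idx.requires_grad)   # False

# Gradients flow back through the max values to x.
max_vals.sum().backward()
print(x.grad.shape)  # torch.Size([10, 20])
```

So if your loss can be written in terms of the max value rather than the max index, backward() works out of the box.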

As proposed here, one possibility is to use softmax, since the hard indexing is not differentiable.
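To sketch the softmax idea: instead of taking the hard index, you can take a softmax-weighted average of the index values (a "soft argmax"), which is differentiable. The batch size, temperature, and float labels below are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

raw_result = torch.randn(4, 10, requires_grad=True)  # stand-in for the model's 10 outputs
labels = torch.tensor([3.0, 7.0, 1.0, 5.0])          # target indices as floats (assumed)

# Soft argmax: softmax weights times the index values approximate the
# argmax index. A temperature < 1 sharpens the weights toward the true argmax.
temperature = 0.1
weights = F.softmax(raw_result / temperature, dim=1)
index_values = torch.arange(10, dtype=torch.float32)
soft_index = (weights * index_values).sum(dim=1)

loss = F.mse_loss(soft_index, labels)
loss.backward()
print(raw_result.grad is not None)  # True -> gradients flow to the model outputs
```

The trade-off is that soft_index is only an approximation of the hard argmax, but it lets loss.backward() reach the model parameters.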