I have a model with 10 outputs. I want to select the index with the maximum value and compute MSELoss against my label. However, if I do _, max_index = torch.max(raw_result, 1), then max_index.requires_grad is False, and that prevents loss.backward() from working. Does anyone have experience with how to do this?
Did you try torch.argmax?
Yeah, same problem:

import torch

x = torch.randn(10, 20)
x.requires_grad = True
output = torch.argmax(x, 1)
output.requires_grad  # False: argmax returns integer indices, which cannot carry gradients
On the other hand, it seems hard to propagate the gradient back through argmax if you look at this post. As proposed there, one possibility is to use softmax instead, because the hard indexing is not differentiable.
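A minimal sketch of that softmax idea: instead of the hard argmax, take the expected index under the softmax distribution (a "soft argmax"). This stays differentiable, so MSELoss against an index-valued label can backpropagate. The batch size, seed, and target values below are made up for illustration.

import torch

torch.manual_seed(0)
logits = torch.randn(4, 10, requires_grad=True)  # batch of 4, 10 outputs each
target = torch.tensor([3.0, 7.0, 1.0, 9.0])      # hypothetical index-valued labels

# Soft argmax: expectation of the index under softmax(logits).
# Unlike torch.argmax, every step here is differentiable.
indices = torch.arange(10, dtype=torch.float32)
probs = torch.softmax(logits, dim=1)
soft_idx = (probs * indices).sum(dim=1)          # shape (4,), float, requires_grad=True

loss = torch.nn.functional.mse_loss(soft_idx, target)
loss.backward()
print(logits.grad is not None)  # True: gradients flow back to the logits

Note this optimizes the expected index, not the argmax itself; with a sharp (low-temperature) softmax the two get close, but they are not identical.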