According to the documentation, `tensor.argmax` should return the index of the first occurrence when there are multiple maximal values.
But when all the values in the tensor are identical, it returns the index of the last one.
For example, `torch.tensor([0, 0, 0, 0]).argmax(dim=0)` outputs `tensor(3)`. Shouldn’t it be 0?
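To make the two tie-breaking rules concrete, here is a minimal pure-Python sketch. `first_argmax` and `last_argmax` are hypothetical helper names used only for illustration, not PyTorch functions. The first follows the "first maximal index" rule that the documentation describes. The second reproduces the "last maximal index" result seen above. The optional PyTorch check runs only if `torch` is installed, and what it prints depends on the installed build. The `(t == t.max()).nonzero()` line shows one version-independent way to get the first maximal index, because `nonzero()` lists matching indices in ascending order.

```python
def first_argmax(values):
    """Index of the first maximal element (the tie-breaking the docs describe)."""
    if not values:
        raise ValueError("argmax of an empty sequence")
    best = 0
    for i, v in enumerate(values):
        if v > values[best]:  # strict '>' keeps the earlier index on ties
            best = i
    return best


def last_argmax(values):
    """Index of the last maximal element (the behavior reported above)."""
    if not values:
        raise ValueError("argmax of an empty sequence")
    best = 0
    for i, v in enumerate(values):
        if v >= values[best]:  # '>=' moves to the later index on ties
            best = i
    return best


print(first_argmax([0, 0, 0, 0]))  # 0
print(last_argmax([0, 0, 0, 0]))   # 3

try:
    import torch

    t = torch.tensor([0, 0, 0, 0])
    # Whatever the installed PyTorch build returns for the example in the post.
    print("torch argmax:", t.argmax(dim=0).item())
    # First index where the max occurs, independent of argmax's tie-breaking.
    print("first max index:", (t == t.max()).nonzero()[0].item())
except ImportError:
    pass  # PyTorch not installed; the pure-Python helpers above still run
```

With the documented rule, an all-equal input should map to index 0. A result of 3 matches the "last occurrence" rule instead.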