If you want to get the index of the maximum value in each row, you could just call .argmax(1):
import torch

x = torch.randn(10, 15)
x.argmax(1)  # shape (10,): column index of each row's maximum
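As a rough sketch, here is one way to check that the returned indices really point at the row maxima. It also shows torch.max(x, dim=1), which returns the max values and their indices together. The fixed seed and the `gather` check are illustrative additions, not part of the original answer.

```python
import torch

torch.manual_seed(0)  # fixed seed (illustrative) so the run is reproducible
x = torch.randn(10, 15)

# argmax over dim 1 (across columns) -> one column index per row, shape (10,)
idx = x.argmax(1)

# torch.max with a dim returns a (values, indices) pair
values, indices = x.max(dim=1)

# Pick out x[i, idx[i]] for every row i and compare with the row maxima
picked = x.gather(1, idx.unsqueeze(1)).squeeze(1)

# keepdim=True keeps the reduced dimension, giving shape (10, 1)
idx_kept = x.argmax(1, keepdim=True)

print(idx.shape)                    # torch.Size([10])
print(torch.equal(picked, values))  # True
print(idx_kept.shape)               # torch.Size([10, 1])
```

If you need the maximum values as well as their positions, x.max(dim=1) avoids a second pass over the tensor.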