torch.argmax: I want more than just one value

I have a PyTorch tensor as the output of a language model (GPT-2), and I use argmax to get the index with the highest probability.

predicted_index = torch.argmax(predictions_2[0, -1, :]).item()

But I need all the indexes sorted by probability.

How do I do this in pytorch?
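A minimal sketch of getting every index ordered from most to least likely. It assumes a small random tensor named `logits` (hypothetical, with an illustrative vocabulary size of 10) standing in for `predictions_2[0, -1, :]`. `torch.argsort(..., descending=True)` returns only the ordered indices, while `torch.sort(..., descending=True)` returns both the sorted values and their indices. The raw outputs are logits rather than probabilities, but softmax is monotonic, so sorting either one gives the same order.

```python
import torch

# Hypothetical stand-in for predictions_2[0, -1, :]: one logit per vocabulary token.
# The vocabulary size of 10 is an assumption for illustration only.
torch.manual_seed(0)
logits = torch.randn(10)

# All indices, ordered from highest to lowest score.
sorted_indices = torch.argsort(logits, descending=True)

# torch.sort returns both the sorted values and their original indices.
sorted_values, sorted_indices_2 = torch.sort(logits, descending=True)

# Softmax is monotonic, so sorting probabilities gives the same order as sorting logits.
probs = torch.softmax(logits, dim=-1)
sorted_by_prob = torch.argsort(probs, descending=True)

print(sorted_indices.tolist())
print(sorted_indices[0].item() == torch.argmax(logits).item())  # True: first entry is the argmax
```

For a full GPT-2 vocabulary this sorts every token, which is fine for inspection but more work than needed if only the first few indices are wanted.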

#####################################################
print("predicted_index:", predicted_index)
print("predictions_2[0, -1, :]:", predictions_2[0, -1, :])

I get:
predicted_index: 484
predictions_2[0, -1, :] field is: tensor([-122.9283, -124.4627, -128.4069, …, -131.5974, -128.7110,
        -125.7269], device='cuda:0')


Have a look at @ptrblck’s answer:

You can get the top-k values and their indices using the torch.topk() API.
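A short sketch of what `torch.topk` returns, using a hypothetical five-element tensor `logits` in place of `predictions_2[0, -1, :]`. The result is a named tuple whose `values` field holds the k largest entries and whose `indices` field holds their positions, both in descending order by default.

```python
import torch

# Hypothetical logits standing in for predictions_2[0, -1, :].
logits = torch.tensor([0.1, 2.5, -1.0, 3.7, 0.9])

# torch.topk returns a named tuple (values, indices), largest first by default.
top = torch.topk(logits, k=3)
print(top.values)   # values 3.7, 2.5, 0.9
print(top.indices)  # tensor([3, 1, 4])

# Tuple unpacking works too, and reads more clearly than indexing [0] and [1].
values, indices = torch.topk(logits, k=3)
```

Setting `k` to the full length of the tensor returns every index in sorted order, which also answers the original "all indexes sorted" question.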


It works now. Thanks very much.

import torch

predicted_k_indexes = torch.topk(predictions_2[0, -1, :], k=3)
prk_0 = predicted_k_indexes[0]  # the top-3 values (logits)
prk_1 = predicted_k_indexes[1]  # the indices of those values
for item11 in prk_1:
    print(item11.item())

output:
484
523
35075


Can you tell me how to store the indices of the top 3 values in a list or tensor?

Double post from here.
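For the question about storing the top-3 indices, a minimal sketch (not taken from the linked thread): the `indices` field of `torch.topk` is already an int64 tensor, and `.tolist()` turns it into a plain Python list. The tensor `logits` is a hypothetical stand-in for `predictions_2[0, -1, :]`.

```python
import torch

# Hypothetical logits standing in for predictions_2[0, -1, :].
logits = torch.tensor([0.1, 2.5, -1.0, 3.7, 0.9])

# Keep the indices as a tensor (dtype torch.int64)...
top_indices = torch.topk(logits, k=3).indices

# ...or convert them to a plain Python list of ints.
top_indices_list = top_indices.tolist()
print(top_indices_list)  # [3, 1, 4]
```

`.tolist()` also works on a CUDA tensor; it copies the data back to the CPU, so it replaces the per-element `.item()` loop with a single call.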