Store index of max values

How can I find the top 3 max values in a tensor, and how can I store their indices in one variable?

The topk function probably does something similar to what you want: torch.topk (PyTorch 1.8.1 documentation)
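As a rough illustration of what torch.topk returns, here is a minimal sketch. The tensor `mul_reward` is hypothetical example data, not the poster's actual tensor. By default topk returns the k largest values and their positions, ordered largest first.

```python
import torch

# Hypothetical reward vector standing in for the poster's tensor.
mul_reward = torch.tensor([0.2, 1.5, 0.7, 3.1, 2.4])

# topk returns the k largest values and their positions, largest first by default.
values, indices = torch.topk(mul_reward, 3)
print(values)   # tensor([3.1000, 2.4000, 1.5000])
print(indices)  # tensor([3, 4, 1])
```

The `indices` tensor already holds all three positions in a single variable, with dtype int64.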


Thank you @eqy

But my question is about how to pick the indices of the top 3 max values and store those indices in a variable. This is a solution I found:

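import torch
# Indices of the 3 largest entries of mul_reward, largest first ([:3], not [:4], for the top 3)
frame = sorted(range(len(mul_reward)), key=lambda i: mul_reward[i], reverse=True)[:3]
print(frame)
con_frame = torch.tensor(frame)  # int64 tensor; torch.Tensor(frame) would give float32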
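One way to check the sort-based approach against the built-in ops is sketched below, again on hypothetical data. It compares the Python `sorted` result with `torch.topk(...).indices` and `torch.argsort(..., descending=True)[:3]`. With distinct values all three give the same indices; when values tie, the order among equal entries is not guaranteed to match.

```python
import torch

mul_reward = torch.tensor([0.2, 1.5, 0.7, 3.1, 2.4])  # hypothetical data

# Pure-Python approach: sort positions by their value, descending, keep 3.
frame = sorted(range(len(mul_reward)), key=lambda i: mul_reward[i].item(), reverse=True)[:3]
con_frame = torch.tensor(frame)

# Vectorized alternatives that operate on the tensor directly (no Python loop).
topk_idx = torch.topk(mul_reward, 3).indices
argsort_idx = torch.argsort(mul_reward, descending=True)[:3]

print(con_frame.tolist(), topk_idx.tolist(), argsort_idx.tolist())
```

The topk version avoids sorting the whole tensor and keeps the work on whatever device `mul_reward` lives on.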

I'm not sure I understand; what happens when you run the following?
_, frame = torch.topk(mul_reward, 3)
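If the rewards come as a batch (one row per episode), the same call can keep all indices in one variable by passing `dim`. This is a sketch under the assumption of a hypothetical 2-D tensor `rewards`; it also uses the `.indices` field of the named tuple that topk returns instead of unpacking with `_`.

```python
import torch

# Hypothetical batch: 2 episodes, 5 frames each.
rewards = torch.tensor([[0.2, 1.5, 0.7, 3.1, 2.4],
                        [4.0, 0.1, 0.9, 2.2, 1.0]])

# dim=1 picks the top 3 frames within each row; result shape is (2, 3).
frame = torch.topk(rewards, 3, dim=1).indices
print(frame)
# tensor([[3, 4, 1],
#         [0, 3, 4]])

# The stored indices can later recover the matching values per row.
top_vals = rewards.gather(1, frame)
```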