How can I find the top 3 max values in a tensor, and how can I store their indices in one variable?
The topk function probably does something similar to what you want: torch.topk - PyTorch 1.8.1 documentation
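As a minimal sketch of what `torch.topk` returns, here it is on a small tensor. The name `scores` and its values are hypothetical example data. `torch.topk` returns both the largest values and their indices, sorted largest first by default.

```python
import torch

# Hypothetical example data standing in for the tensor in the question.
scores = torch.tensor([0.1, 0.9, 0.4, 0.7, 0.3])

# torch.topk returns a (values, indices) named tuple; largest=True and sorted=True by default.
values, indices = torch.topk(scores, 3)

print(values)   # tensor([0.9000, 0.7000, 0.4000])
print(indices)  # tensor([1, 3, 2])
```

The indices come back as an int64 tensor, so they can be stored in a single variable directly.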
Thank you @eqy
But my question is how to pick the indices of the top 3 max values and store those indices in a variable. This is a solution I found:
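```python
import torch
# Sort the indices of mul_reward by value, largest first, and keep the first 3.
frame = sorted(range(len(mul_reward)), key=lambda i: mul_reward[i], reverse=True)[:3]
print(frame)
# torch.tensor keeps the indices as int64; torch.Tensor would cast them to float.
con_frame = torch.tensor(frame)
```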
I'm not sure I understand. What happens when you run the following?

```python
_, frame = torch.topk(mul_reward, 3)
```
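Here is a small sketch comparing the two approaches. It assumes a hypothetical 1-D `mul_reward` with distinct values. `torch.topk` gives the same indices as the sort-based solution, and they are already an integer tensor, so no conversion step is needed.

```python
import torch

# Hypothetical stand-in for mul_reward from the question (1-D, no ties).
mul_reward = torch.tensor([2.0, 5.0, 1.0, 7.0, 3.0])

# Indices of the 3 largest values, returned as an int64 tensor.
_, frame = torch.topk(mul_reward, 3)

# The pure-Python sort approach, keeping the first 3 indices.
frame_sorted = sorted(range(len(mul_reward)),
                      key=lambda i: mul_reward[i].item(), reverse=True)[:3]

print(frame)         # tensor([3, 1, 4])
print(frame_sorted)  # [3, 1, 4]
```

With tied values, the two approaches may order the tied indices differently. For multi-dimensional tensors, `torch.topk` also takes a `dim` argument (the last dimension by default).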