Indexing with topk

Hello, I have a tensor X of shape (Batch, Nodes, Features) and another tensor, scores, of shape (Batch, Nodes), holding one score per node. For each batch element, I would like to select the x nodes with the highest scores.

I managed to make it work without a batch dimension, but I don't know how to extend it to batches. Here is the unbatched version:

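import numpy as np
import torch

X = torch.from_numpy(np.array([[1, 2], [3, 4], [5, 6]]))  # (Nodes, Features)
scores = torch.from_numpy(np.array([10, 2, 15]))           # (Nodes,)
values, idx = torch.topk(scores, 2, dim=-1)                # indices of the 2 highest scores
new_X = X[idx, :]                                          # (2, Features)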
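A minimal sketch of one way to extend this to batches, assuming scores has shape (Batch, Nodes): run topk along the node dimension (dim=1), then index X with a batch index that broadcasts against the top-k indices. The tensor values and k = 2 below are made up for illustration. An equivalent alternative using torch.gather is included for comparison.

```python
import torch

k = 2  # hypothetical number of nodes to keep per batch element

X = torch.tensor([[[1, 2], [3, 4], [5, 6]],
                  [[7, 8], [9, 10], [11, 12]]])  # (Batch, Nodes, Features)
scores = torch.tensor([[10, 2, 15],
                       [1, 20, 5]])              # (Batch, Nodes)

# Top-k along the node dimension, independently for each batch element.
values, idx = torch.topk(scores, k, dim=1)       # idx: (Batch, k)

# Option 1: advanced indexing. batch_idx (Batch, 1) broadcasts against
# idx (Batch, k), so new_X[b, j] = X[b, idx[b, j]].
batch_idx = torch.arange(X.size(0)).unsqueeze(1)
new_X = X[batch_idx, idx]                        # (Batch, k, Features)

# Option 2: torch.gather along dim=1, with idx expanded over the features.
gather_idx = idx.unsqueeze(-1).expand(-1, -1, X.size(-1))  # (Batch, k, Features)
new_X_gather = torch.gather(X, 1, gather_idx)              # (Batch, k, Features)

print(new_X)
```

With these inputs the first batch element keeps nodes 2 and 0 and the second keeps nodes 1 and 2, in descending score order, since topk returns sorted results by default. If the original node order matters, sorting idx along dim=1 before indexing would preserve it.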

Thank you in advance.