Hello, I have a tensor X of shape (Batch, Nodes, Features) and another tensor, scores, of shape (Batch, Node_score). For each element of the batch, I would like to select the k nodes with the highest scores.
I managed to make it work for a single sample, but I do not know how to extend it to handle batches. Here is the unbatched version:
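    import numpy as np
    import torch

    X = torch.from_numpy(np.array([[1, 2], [3, 4], [5, 6]]))  # (Nodes, Features)
    scores = torch.from_numpy(np.array([10, 2, 15]))          # (Nodes,)
    values, idx = torch.topk(scores, 2, dim=-1)               # indices of the 2 highest scores
    new_X = X[idx, :]                                         # (2, Features)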
Thank you in advance.
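
For reference, here is a minimal sketch of one way the batched case could be handled, assuming scores has shape (Batch, Nodes), i.e. one score per node. The helper name `topk_nodes` and the example values are hypothetical, chosen only for illustration. The idea is to take the top-k indices per batch element with `torch.topk`, then expand them over the feature dimension so that `torch.gather` can pick whole node rows out of X.

```python
import torch

def topk_nodes(X, scores, k):
    """Hypothetical helper: select the k highest-scoring nodes per batch element.

    Assumes X has shape (B, N, F) and scores has shape (B, N).
    """
    # idx: (B, k), node indices sorted by descending score within each batch element
    _, idx = torch.topk(scores, k, dim=-1)
    # Expand the indices over the feature dimension: (B, k) -> (B, k, F)
    idx_exp = idx.unsqueeze(-1).expand(-1, -1, X.size(-1))
    # Gather along the node dimension (dim=1), giving a (B, k, F) result
    return torch.gather(X, 1, idx_exp), idx

# Illustrative data: 2 batch elements, 3 nodes, 2 features
X = torch.tensor([[[1, 2], [3, 4], [5, 6]],
                  [[7, 8], [9, 10], [11, 12]]])
scores = torch.tensor([[10.0, 2.0, 15.0],
                       [1.0, 20.0, 3.0]])

new_X, idx = topk_nodes(X, scores, 2)
print(new_X.shape)  # torch.Size([2, 2, 2])
```

An equivalent alternative is advanced indexing with a broadcast batch index, `X[torch.arange(X.size(0)).unsqueeze(1), idx]`, which avoids expanding the indices and should give the same (B, k, F) result.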