I am trying to use a reference matrix to slice another matrix with more dimensions:
```python
import torch

# The tensor that we use for indexing
reference = torch.rand(bsize, nfeat)
index = torch.argsort(reference)[:, :bottomk]  # shape (bsize, bottomk)

# The target tensor that we want to slice
target = torch.rand(bsize, nsample, nfeat)
```
How can I slice `target` with `index` so that each batch's feature indices are applied to every sample, i.e. broadcast automatically over the second (`nsample`) dimension? Can you suggest how to do it?
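If "slicing" means reading the selected features out of `target`, one possible approach is `torch.gather` with `index` expanded over the `nsample` dimension. This is a minimal sketch. The concrete sizes are hypothetical example values, not taken from the question.

```python
import torch

# Hypothetical example sizes, chosen only for illustration.
bsize, nsample, nfeat, bottomk = 4, 3, 10, 2

reference = torch.rand(bsize, nfeat)
index = torch.argsort(reference)[:, :bottomk]  # (bsize, bottomk)
target = torch.rand(bsize, nsample, nfeat)

# Broadcast index over the nsample dimension:
# (bsize, bottomk) -> (bsize, 1, bottomk) -> (bsize, nsample, bottomk).
# expand() returns a view, so no data is copied.
expanded = index.unsqueeze(1).expand(-1, nsample, -1)

# Gather along the feature dimension:
# selected[i, j, k] = target[i, j, index[i, k]]
selected = torch.gather(target, 2, expanded)  # (bsize, nsample, bottomk)
print(selected.shape)
```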
```python
# The for-loop version. But I want a more elegant and faster version
# that does not loop over batch
for i in range(bsize):
    target[i, :, index[i, :]] = 0
```
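For the in-place zeroing that the loop performs, here is a minimal sketch of two loop-free alternatives. It assumes PyTorch's NumPy-style advanced indexing and `Tensor.masked_fill_` with a broadcastable mask. The sizes and helper names (`batch_idx`, `mask`, `out_a`, `out_b`) are hypothetical and used only for illustration. Both results are checked against the original loop.

```python
import torch

# Hypothetical example sizes, chosen only for illustration.
bsize, nsample, nfeat, bottomk = 4, 3, 10, 2

reference = torch.rand(bsize, nfeat)
index = torch.argsort(reference)[:, :bottomk]  # (bsize, bottomk)
target = torch.rand(bsize, nsample, nfeat)

# Reference result: the original per-batch loop, applied to a copy.
expected = target.clone()
for i in range(bsize):
    expected[i, :, index[i, :]] = 0

# Option A: advanced indexing. batch_idx has shape (bsize, 1) and broadcasts
# against index (bsize, bottomk); the ':' keeps every sample in dim 1.
batch_idx = torch.arange(bsize).unsqueeze(1)
out_a = target.clone()
out_a[batch_idx, :, index] = 0

# Option B: build a (bsize, nfeat) boolean mask once, then let masked_fill_
# broadcast it over the nsample dimension through a (bsize, 1, nfeat) view.
mask = torch.zeros(bsize, nfeat, dtype=torch.bool)
mask[batch_idx, index] = True
out_b = target.clone().masked_fill_(mask.unsqueeze(1), 0)

print(torch.equal(out_a, expected), torch.equal(out_b, expected))
```

In Option A, the two tensor indices are separated by a slice. Under NumPy's indexing rules, reading that expression would therefore place the broadcast `(bsize, bottomk)` dimensions first. For assigning a scalar, that ordering does not matter. Option B may be handy when the same mask is reused on several tensors.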