I am trying to use a reference matrix to slice another matrix with more dimensions:
import torch
# The tensor that we use for indexing
reference = torch.rand(bsize, nfeat)
index = torch.argsort(reference)[:, :bottomk]
# The target tensor that we want to slice
target = torch.rand(bsize, nsample, nfeat)
How can I slice target using index? I want the indexing to broadcast automatically over the second dimension (nsample). Can you suggest how to do it?
# The for-loop version. But I want a more elegant and faster version
# that does not loop over batch
for i in range(bsize):
    target[i, :, index[i, :]] = 0
Why does target have size bsize * nsample * nfeat? How do you want to slice it when your reference has size bsize * nfeat? Please give a detailed example, and I can probably help.
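For reference, here is one possible loop-free version of the for-loop in the question. It is a minimal sketch, not necessarily the only or the fastest way. It uses advanced indexing with three index tensors that broadcast to shape (bsize, nsample, bottomk). The function name zero_selected_features and the concrete sizes are made up for illustration; the sketch assumes index is an integer tensor on the same device as target.

```python
import torch


def zero_selected_features(target, index):
    """Set target[i, :, index[i, j]] = 0 for every batch i and every j, in place.

    Hypothetical helper for illustration: target is (bsize, nsample, nfeat),
    index is an integer tensor of shape (bsize, bottomk).
    """
    bsize, nsample, _ = target.shape
    # Index tensors shaped so that they broadcast to (bsize, nsample, bottomk).
    batch_idx = torch.arange(bsize, device=target.device)[:, None, None]     # (bsize, 1, 1)
    sample_idx = torch.arange(nsample, device=target.device)[None, :, None]  # (1, nsample, 1)
    feat_idx = index[:, None, :]                                             # (bsize, 1, bottomk)
    target[batch_idx, sample_idx, feat_idx] = 0
    return target


# Example sizes are assumptions chosen only to make the sketch runnable.
torch.manual_seed(0)
bsize, nsample, nfeat, bottomk = 4, 3, 6, 2

reference = torch.rand(bsize, nfeat)
index = torch.argsort(reference)[:, :bottomk]
target = torch.rand(bsize, nsample, nfeat)

# Reference result from the original for-loop.
expected = target.clone()
for i in range(bsize):
    expected[i, :, index[i, :]] = 0

# Vectorized result, computed on a copy so the two can be compared.
result = zero_selected_features(target.clone(), index)
```

Here result should match expected from the loop. The key point is that index only needs an extra singleton dimension in the nsample position; the arange tensors supply the batch and sample coordinates, and broadcasting takes care of the rest.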