Broadcasting Index

I am trying to use a reference matrix to slice another matrix with more dimensions:

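import torch
# Example sizes, assumed here only so the snippet runs
bsize, nsample, nfeat, bottomk = 4, 3, 10, 2

# The tensor that we use for indexing
reference = torch.rand(bsize, nfeat)
# Indices of the bottomk smallest features in each row: (bsize, bottomk)
index = torch.argsort(reference)[:, :bottomk]
# The target tensor that we want to slice
target = torch.rand(bsize, nsample, nfeat)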
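For reference, here is what index holds. argsort sorts ascending along the last dimension, so its first bottomk columns are the positions of the bottomk smallest features in each row. A small sketch, with sizes chosen arbitrarily for illustration:

```python
import torch

# Example sizes (assumed for illustration; the question leaves them unspecified)
bsize, nsample, nfeat, bottomk = 2, 3, 5, 2

torch.manual_seed(0)
reference = torch.rand(bsize, nfeat)
# argsort sorts ascending along the last dim, so the first `bottomk`
# columns hold the positions of the `bottomk` smallest features per row
index = torch.argsort(reference)[:, :bottomk]
print(index.shape)  # torch.Size([2, 2]) -> (bsize, bottomk)

# Reading those positions back from `reference` yields its smallest values
smallest = reference.gather(1, index)
expected = reference.sort(dim=1).values[:, :bottomk]
print(torch.equal(smallest, expected))  # True
```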

How can I perform slicing using index? I want the second dimension to be broadcast automatically. Can you suggest how to do it?

# The for-loop version; I want a more elegant and faster version
# that does not loop over the batch
for i in range(bsize):
    target[i, :, index[i, :]] = 0
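One loop-free way to get the broadcasting asked for here is plain advanced indexing: give the batch and sample positions their own arange index tensors with singleton dimensions, and give index a singleton sample dimension, so all three broadcast to (bsize, nsample, bottomk). A sketch with arbitrary example sizes, checked against the loop above:

```python
import torch

# Example sizes (assumed for illustration)
bsize, nsample, nfeat, bottomk = 4, 3, 10, 2

torch.manual_seed(0)
reference = torch.rand(bsize, nfeat)
index = torch.argsort(reference)[:, :bottomk]        # (bsize, bottomk)
target = torch.rand(bsize, nsample, nfeat)

# Reference result: the loop from the question, run on a copy
looped = target.clone()
for i in range(bsize):
    looped[i, :, index[i, :]] = 0

# Loop-free version: three index tensors that broadcast to
# (bsize, nsample, bottomk). `index` gets a singleton sample dim,
# so it is broadcast across the second dimension automatically.
batch_idx = torch.arange(bsize)[:, None, None]       # (bsize, 1, 1)
sample_idx = torch.arange(nsample)[None, :, None]    # (1, nsample, 1)
vectorized = target.clone()
vectorized[batch_idx, sample_idx, index[:, None, :]] = 0

print(torch.equal(looped, vectorized))  # True
```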

Why is target of size bsize * nsample * nfeat? How do you want to slice it when your reference is of size bsize * nfeat? Please give a detailed example; then I can probably help with it.

I edited the question to clarify it.

I can't say that gather() is fast, but the following should be a loop-free equivalent for reading:

target.gather(-1, index.unsqueeze(1).expand(-1, nsample, -1))
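To check the gather call against a per-batch loop, here is a sketch with arbitrary example sizes. gather needs an index with as many dimensions as target, which is why index is unsqueezed and expanded; expand returns a view, so index is not copied.

```python
import torch

# Example sizes (assumed for illustration)
bsize, nsample, nfeat, bottomk = 4, 3, 10, 2

torch.manual_seed(0)
reference = torch.rand(bsize, nfeat)
index = torch.argsort(reference)[:, :bottomk]   # (bsize, bottomk)
target = torch.rand(bsize, nsample, nfeat)

# gather needs an index with the same number of dims as `target`, so add a
# sample dim and expand it (a view, no copy) to (bsize, nsample, bottomk)
expanded = index.unsqueeze(1).expand(-1, nsample, -1)
picked = target.gather(-1, expanded)             # (bsize, nsample, bottomk)

# Same values as a per-batch loop: picked[i] == target[i][:, index[i]]
looped = torch.stack([target[i][:, index[i]] for i in range(bsize)])
print(picked.shape)                  # torch.Size([4, 3, 2])
print(torch.equal(picked, looped))   # True
```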

For writing, look at Tensor.scatter_ or Tensor.masked_fill_, expanding the index or mask in the same way.
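Sketches of both write paths, with arbitrary example sizes and checked against the loop from the question. scatter_ takes the same expanded index as gather; for masked_fill_, a (bsize, nfeat) boolean mask is built once and broadcast over the sample dimension with unsqueeze(1).

```python
import torch

# Example sizes (assumed for illustration)
bsize, nsample, nfeat, bottomk = 4, 3, 10, 2

torch.manual_seed(0)
reference = torch.rand(bsize, nfeat)
index = torch.argsort(reference)[:, :bottomk]            # (bsize, bottomk)
target = torch.rand(bsize, nsample, nfeat)

# Reference result from the question's loop
looped = target.clone()
for i in range(bsize):
    looped[i, :, index[i, :]] = 0

# Option 1: scatter_ writes a scalar value at the expanded index positions
expanded = index.unsqueeze(1).expand(-1, nsample, -1)    # (bsize, nsample, bottomk)
scattered = target.clone()
scattered.scatter_(-1, expanded, 0.0)

# Option 2: build a (bsize, nfeat) boolean mask once, then let
# masked_fill_ broadcast it over the sample dimension via unsqueeze(1)
mask = torch.zeros(bsize, nfeat, dtype=torch.bool)
mask[torch.arange(bsize).unsqueeze(1), index] = True
masked = target.clone()
masked.masked_fill_(mask.unsqueeze(1), 0.0)

print(torch.equal(looped, scattered), torch.equal(looped, masked))  # True True
```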
