# Advanced indexing with torch.topk

Suppose I have a 3D tensor `x`, and I run `torch.topk(x, k=2, dim=0)` to retrieve the indices of the two largest values along dimension 0.

Then I want to use those indices to index the tensor and assign a value at those positions. However, I was not able to write the correct advanced-indexing expression. The only thing I managed to get working is:

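```python
import torch

k, val = 2, 100.0  # number of top values per position, value to assign
x = torch.randn(3, 10, 10)
_, H, W = x.shape
_, inds = torch.topk(x, k=k, dim=0)  # inds has shape (k, H, W)

# flattened (h, w) coordinates of every spatial position, in row-major order
h = torch.arange(H).long().repeat(W, 1).transpose(0, 1).contiguous().view(-1)
w = torch.arange(W).long().repeat(H)
for i in range(k):
    x[inds[i, :, :].reshape(-1), h, w] = val
```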

However, this is very inefficient. Does anyone know a more efficient approach?


Would this work?

```python
import torch

x = torch.randn(3, 10, 10)
idx = torch.topk(x, k=2, dim=0).indices  # shape (2, 10, 10)
x.scatter_(0, idx, 100)  # write 100 at the top-2 positions along dim 0
print(x)
```
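Since the question asked about advanced indexing specifically, here is a sketch comparing `scatter_` with an equivalent advanced-indexing assignment. It assumes the same `(3, 10, 10)` shape as above; the coordinate tensors `h` and `w` are introduced here for illustration and are shaped so that they broadcast against the `(k, H, W)` index tensor, which removes the Python loop from the question.

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, 10, 10)
k, val = 2, 100.0

# indices of the k largest entries along dim 0; shape (k, H, W)
idx = torch.topk(x, k=k, dim=0).indices

# Version 1: scatter_ writes val at x[idx[i, h, w], h, w] for every i, h, w
out_scatter = x.clone().scatter_(0, idx, val)

# Version 2: advanced indexing with broadcast coordinate tensors.
# h has shape (1, H, 1) and w has shape (1, 1, W); together with idx of
# shape (k, H, W) they broadcast to (k, H, W), so each entry of idx is
# paired with its own (h, w) position.
_, H, W = x.shape
h = torch.arange(H).view(1, H, 1)
w = torch.arange(W).view(1, 1, W)
out_index = x.clone()
out_index[idx, h, w] = val

print(torch.equal(out_scatter, out_index))  # True
```

Both versions touch exactly `k` entries per `(h, w)` position; `scatter_` is the more compact spelling, while the indexing form generalizes to cases where the three coordinates come from different sources.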

It works now. Thanks very much.

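```python
# predictions_2 holds the model scores; take the top 3 at the last position
predicted_k_indexes = torch.topk(predictions_2[0, -1, :], k=3)
prk_0 = predicted_k_indexes[0]  # top-3 values
prk_1 = predicted_k_indexes[1]  # top-3 indices
for item11 in prk_1:
    print(item11.item())
```

Output:

```
484
523
35075
```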
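`torch.topk` returns a named tuple of `(values, indices)`, so the indices have to be taken from the second field before iterating over them. A minimal sketch, using a small hypothetical `scores` vector in place of `predictions_2[0, -1, :]`:

```python
import torch

# hypothetical stand-in for predictions_2[0, -1, :]: a 1-D score vector
scores = torch.tensor([0.1, 2.5, -1.0, 3.7, 0.9])

result = torch.topk(scores, k=3)
values, indices = result  # unpack the (values, indices) named tuple

# .indices and [1] refer to the same tensor
same = torch.equal(result.indices, result[1])

for i in indices:
    print(i.item())  # prints 3, 1, 4 (largest score first)
```

Using `indices.tolist()` gives the same integers as a plain Python list without the explicit loop.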

Can you please explain the purpose of the third argument in this call?

```python
x.scatter_(0, idx, 100)  # <- the third argument, 100
```

The third argument is the source of the values to write: `src` if you pass a tensor, or `value` if you pass a scalar. It is scattered into `x` along dimension `0` at the positions given by `idx`.
In my code snippet, the scalar `100` is written to `x`.
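To see the difference between a scalar and a tensor as the third argument, here is a small sketch with a hand-written index tensor (a stand-in for the output of `torch.topk`, chosen so each `(h, w)` position gets two distinct rows). With a scalar, every indexed position receives the same value; with a tensor `src`, element `src[i, h, w]` is written to `x[idx[i, h, w], h, w]`.

```python
import torch

x = torch.zeros(3, 2, 2)
# idx has shape (2, 2, 2): for each (h, w), which rows along dim 0 to write
idx = torch.tensor([[[0, 1], [2, 0]],
                    [[1, 2], [0, 1]]])

# scalar form: the same value is written at every indexed position
a = x.clone().scatter_(0, idx, 100.0)

# tensor form: src[i, h, w] is written to x[idx[i, h, w], h, w]
src = torch.arange(1.0, 9.0).view(2, 2, 2)
b = x.clone().scatter_(0, idx, src)
print(b)
```

Positions of `x` that `idx` never points to keep their original value (zero here) in both forms.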
