Advanced indexing with torch.topk

Suppose I have a 3D tensor x, and I run torch.topk(x, k=2, dim=0)[1] to retrieve the indices of the two largest values along dimension 0.
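Here is a small sketch of what that call returns. The tensor sizes are arbitrary choices for illustration. torch.topk gives a (values, indices) pair whose shapes match x, except that dimension 0 shrinks to k:

```python
import torch

x = torch.randn(3, 4, 5)
values, indices = torch.topk(x, k=2, dim=0)

# dim 0 shrinks from 3 to k=2; the other dimensions are unchanged
print(values.shape)   # torch.Size([2, 4, 5])
print(indices.shape)  # torch.Size([2, 4, 5])

# indices[i, h, w] is the position along dim 0 of the (i+1)-th largest x[:, h, w],
# so gathering x at those positions recovers the top-k values
assert torch.equal(values, torch.gather(x, 0, indices))
```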

I then want to use those indices to index the tensor and assign a value. However, I have not been able to write the correct advanced indexing. The only thing that works for me is:

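    import torch

    x = torch.randn(3, 10, 10)
    k, val = 2, 100.0  # example k and value to assign
    _, H, W = x.shape
    inds = torch.topk(x, k=k, dim=0)[1]  # shape (k, H, W)

    # Row/column coordinates of every (h, w) position, flattened row-major
    h = torch.arange(H).repeat(W, 1).transpose(0, 1).contiguous().view(-1)
    w = torch.arange(W).repeat(H)
    for i in range(k):
        x[inds[i].reshape(-1), h, w] = val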

However, this is very inefficient. Does anyone know a more efficient way?


Would this work?

    import torch

    x = torch.randn(3, 10, 10)
    idx = torch.topk(x, k=2, dim=0)[1]
    # Write 100 at x[idx[i, h, w], h, w] for every i, h, w
    x.scatter_(0, idx, 100)
    print(x)
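The scatter_ call writes 100 at x[idx[i, h, w], h, w] for every i, h, w, which is what the loop in the question does. Below is a minimal sketch that checks this. It also shows a broadcasting advanced-indexing alternative. The names rows, cols, a, b and c are illustrative, not from the original posts:

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, 10, 10)
_, H, W = x.shape
idx = torch.topk(x, k=2, dim=0)[1]  # shape (2, H, W)

# 1) scatter_ along dim 0: a[idx[i, h, w], h, w] = 100
a = x.clone().scatter_(0, idx, 100)

# 2) Advanced indexing: rows (H, 1) and cols (W,) broadcast against idx (2, H, W)
rows = torch.arange(H).view(H, 1)
cols = torch.arange(W)
b = x.clone()
b[idx, rows, cols] = 100

# 3) The loop from the question, one top-k slice at a time
c = x.clone()
for i in range(idx.shape[0]):
    c[idx[i], rows, cols] = 100

print(torch.equal(a, b), torch.equal(a, c))  # True True
```

Neither vectorized form builds Python lists of indices. scatter_ is arguably the more direct fit when the indices come straight from topk along the same dimension.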

It works now. Thanks very much.

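    import torch

    # predictions_2 is not defined in the post; its shape is assumed to be (batch, seq_len, vocab_size)
    predicted_k_indexes = torch.topk(predictions_2[0, -1, :], k=3)
    prk_0 = predicted_k_indexes[0]  # top-3 values
    prk_1 = predicted_k_indexes[1]  # indices of the top-3 values
    for item11 in prk_1:
        print(item11.item())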

Output:

    484
    523
    35075
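torch.topk returns a named tuple, so you can access the values and indices by name and convert the indices to a plain list in one call. The sketch below uses a random tensor of a hypothetical vocabulary size in place of predictions_2, so the printed indices will differ from the output above:

```python
import torch

torch.manual_seed(0)
# Hypothetical stand-in for predictions_2: shape (batch, seq_len, vocab_size)
predictions = torch.randn(1, 5, 50000)

top = torch.topk(predictions[0, -1, :], k=3)
print(top.values)            # the three largest scores, in descending order
print(top.indices.tolist())  # their positions, as Python ints
```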