Suppose I have a 3D tensor `x`, and I run `torch.topk(x, k=2, dim=0)[1]` to retrieve the indices of the two largest values along dimension 0.
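For reference, here is a small sketch of what those indices look like. It assumes a hypothetical `(C, H, W) = (4, 3, 5)` input. `topk` along dim 0 returns indices of shape `(k, H, W)`, where `inds[i, h, w]` is the channel holding the (i+1)-th largest value at position `(h, w)`. `torch.gather` along the same dim turns those indices back into the values.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 3, 5)  # hypothetical (C, H, W) example input

# Indices of the 2 largest values along dim 0; shape (k, H, W) = (2, 3, 5)
inds = torch.topk(x, k=2, dim=0)[1]

# Gathering along dim 0 with those indices recovers the top-k values themselves
top_vals = torch.gather(x, 0, inds)
```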
Then I want to use those indices to index the tensor and assign a value, but I can't figure out how to write the correct advanced indexing. The only thing I was able to do is:
```python
import torch
x = torch.randn(4, 3, 5)  # example (C, H, W) input
k, val = 2, 0.0            # example values
_, H, W = x.shape
inds = torch.topk(x, k=k, dim=0)[1]  # shape (k, H, W)
h = torch.arange(H).repeat(W, 1).transpose(0, 1).contiguous().view(-1)
w = torch.arange(W).repeat(H)
for i in range(k):
    x[inds[i].reshape(-1), h, w] = val
```
However, this is very inefficient. Is there a more efficient way to do it?
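A possible vectorized approach follows. It is a sketch, not the only option. Because `inds` already has shape `(k, H, W)`, it can be passed directly to `Tensor.scatter_` along dim 0. Alternatively, it can be used for advanced indexing together with `arange` index tensors shaped to broadcast against it. The helper names `set_topk_scatter` and `set_topk_indexing` are hypothetical, and both modify `x` in place with no Python loop over `k`.

```python
import torch

def set_topk_scatter(x: torch.Tensor, k: int, val: float) -> torch.Tensor:
    """Hypothetical helper: set the k largest entries along dim 0 to val (in place)."""
    inds = torch.topk(x, k=k, dim=0)[1]  # (k, H, W)
    # For every (i, h, w): x[inds[i, h, w], h, w] = val
    x.scatter_(0, inds, val)
    return x

def set_topk_indexing(x: torch.Tensor, k: int, val: float) -> torch.Tensor:
    """Hypothetical helper: same result via broadcasting advanced indexing."""
    _, H, W = x.shape
    inds = torch.topk(x, k=k, dim=0)[1]                 # (k, H, W)
    h = torch.arange(H, device=x.device).view(1, H, 1)  # broadcasts to (k, H, W)
    w = torch.arange(W, device=x.device).view(1, 1, W)  # broadcasts to (k, H, W)
    x[inds, h, w] = val
    return x
```

Both versions should give the same result as the loop above. `scatter_` is arguably the most direct fit, since "write a value at these indices along one dim" is exactly what it does.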