# Get top k indices/values of all rows

Hi, I’m looking to get the topk gradients of all rows, not topk of each row. For example, if I have a conv layer of shape [64, 64, 3, 3] and k=2, I only want 2 top values and their corresponding indices returned. Ultimately, I want a new tensor with a shape matching the dimensions of the original weight, with all elements zeroed out except the top k gradients.

Here is what I have now; I'm quite stuck on the mapping from flattened indices back to multi-dimensional ones:

```
def get_topk_idx(grad):
    # Walk a nested list, yielding each scalar together with its multi-dim index
    stack = [enumerate(grad)]
    path = [None]
    while stack:
        for path[-1], x in stack[-1]:
            if isinstance(x, list):
                stack.append(enumerate(x))
                path.append(None)
                break
            else:
                yield x, tuple(path)
        else:
            stack.pop()
            path.pop()

idx_map = []  # flat_idx --> (val, idx)
for i in get_topk_idx(grad.tolist()):
    idx_map.append(i)

for i in range(len(topk_indices)):
    _, idx_expanded = idx_map[topk_indices[i]]
    topk_tensor[idx_expanded] = grad[idx_expanded]
```

Thanks in advance for the help!

You could calculate the `topk` of the flattened tensor and use this implementation to unravel the indices:

```
import torch

def unravel_index(index, shape):
    # Convert flat indices into one index tensor per dimension
    # (last dimension varies fastest, matching row-major layout)
    out = []
    for dim in reversed(shape):
        out.append(index % dim)
        index = index // dim
    return tuple(reversed(out))

x = torch.randn(2, 3, 4, 5)
res = torch.topk(x.view(-1), k=3)

idx = unravel_index(res.indices, x.size())
print(x[idx] == res.values)
> tensor([True, True, True])
```
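For the final goal (a tensor with the original shape where everything but the top k gradients is zeroed), you don't actually need the unravelled indices at all: you can select and copy on the flattened view and reshape back. A minimal sketch, where `topk_mask_like` is a hypothetical helper name and selecting by absolute magnitude is an assumption (drop the `.abs()` if you want the largest signed values instead):

```python
import torch

def topk_mask_like(grad, k):
    # Flatten, find the k largest-magnitude entries, and copy only
    # those entries into an otherwise all-zero tensor of the same shape.
    flat = grad.reshape(-1)
    topk = torch.topk(flat.abs(), k)           # assumption: top-k by magnitude
    out = torch.zeros_like(flat)
    out[topk.indices] = flat[topk.indices]     # keep original (signed) values
    return out.reshape(grad.shape)

grad = torch.randn(64, 64, 3, 3)
sparse = topk_mask_like(grad, k=2)
print((sparse != 0).sum())
```

Recent PyTorch versions also ship a built-in `torch.unravel_index`, in case you do need the expanded indices elsewhere.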