Get top k indices/values of all rows

Hi, I’m looking to get the top-k gradients across the entire tensor, not the top-k of each row. For example, if I have a conv layer of shape [64, 64, 3, 3] and k=2, I want only the 2 largest values (and their corresponding indices) returned. Ultimately, I want a new tensor with the same shape as the original weight, in which every element is zeroed out except the top-k gradients.

Here is what I have now, which is quite slow because of the mapping from flattened indices back to expanded (multi-dimensional) indices:

def get_topk_idx(grad):
  """Lazily yield ``(value, index_tuple)`` for every element of ``grad``.

  Despite its name, nothing is selected here: the tensor is detached, moved
  to the CPU, converted to nested Python lists, and walked in row-major
  order, pairing each scalar with its full multi-dimensional index.
  Nothing is computed until the first item is requested.
  """
  def _walk(node, prefix):
    # Depth-first walk; ``prefix`` is the index path leading to ``node``.
    for pos, item in enumerate(node):
      here = prefix + (pos,)
      if isinstance(item, list):
        yield from _walk(item, here)
      else:
        yield item, here

  yield from _walk(grad.detach().cpu().tolist(), ())

def get_topk_grads(grad, num_gradients):
  """Return a tensor shaped like ``grad`` keeping only its top-k entries by magnitude.

  The ``num_gradients`` entries of ``grad`` with the largest absolute value
  are copied (with their original sign) into an otherwise all-zero tensor.
  The selection is taken over the whole tensor, not per row. Works for any
  rank, device and dtype; the result has the same device/dtype as ``grad``.

  Args:
    grad: tensor to select from (e.g. a weight's ``.grad``).
    num_gradients: number of entries to keep. Values larger than
      ``grad.numel()`` keep every entry.

  Returns:
    A new tensor with the same shape as ``grad``.
  """
  flat = grad.reshape(-1)
  # topk raises if k exceeds the number of elements; keep everything instead.
  k = min(num_gradients, flat.numel())
  # Rank by magnitude, but copy the signed values below.
  _, topk_indices = flat.abs().topk(k)
  topk_flat = torch.zeros_like(flat)
  topk_flat[topk_indices] = flat[topk_indices]
  # topk_flat is contiguous, so viewing it with grad's shape restores the
  # original layout regardless of rank (no hard-coded 4-D indexing).
  return topk_flat.view(grad.shape)

Thanks in advance for the help!

You could calculate the topk of the flattened tensor and then use the following implementation to unravel the flat indices back into per-dimension indices:

def unravel_index(index, shape):
    """Map a flat row-major ``index`` to a tuple with one index per dimension of ``shape``.

    ``index`` may be a Python int or an integer tensor of flat indices. In
    the tensor case every returned coordinate is a tensor of the same
    shape, so the tuple can be used directly for advanced indexing.
    """
    remainder = index
    coords_last_first = []
    # Peel off the fastest-varying (last) dimension first.
    for size in reversed(shape):
        coords_last_first.append(remainder % size)
        remainder = remainder // size
    return tuple(coords_last_first[::-1])


# Take the 3 largest entries of a 4-D tensor through its flattened view,
# then turn the flat positions back into 4-D coordinates.
x = torch.randn(2, 3, 4, 5)
res = x.view(-1).topk(k=3)

idx = unravel_index(res.indices, x.shape)
# Indexing with the tuple of coordinate tensors gathers the same elements
# that topk returned, so every comparison should print True.
print(x[idx] == res.values)
> tensor([True, True, True])