Hi, I’m looking to get the topk gradients of all rows, not topk of each row. For example, if I have a conv layer of shape [64, 64, 3, 3] and k=2, I only want 2 top values and their corresponding indices returned. Ultimately, I want a new tensor with a shape matching the dimensions of the original weight, with all elements zeroed out except the top k gradients.
Here is what I have now, which is quite cumbersome because of the mapping from flattened indices back to expanded (multi-dimensional) indices:
def get_topk_idx(grad):
    """Yield (value, index_tuple) for every scalar in *grad*, in row-major order.

    The tensor is converted to a nested Python list and walked depth-first,
    so the n-th yielded pair corresponds to the n-th element of the
    flattened tensor.
    """
    def _walk(node, prefix):
        # Recurse into sub-lists, extending the index path; leaves are scalars.
        for pos, item in enumerate(node):
            if isinstance(item, list):
                yield from _walk(item, prefix + (pos,))
            else:
                yield item, prefix + (pos,)

    yield from _walk(grad.detach().cpu().tolist(), ())
def get_topk_grads(grad, num_gradients):
    """Return a tensor shaped like *grad* that keeps only the global top-k entries.

    The top-k is taken over ALL elements (by absolute value), not per row.
    Every other element is zero; the kept elements retain their original
    (signed) gradient values.

    Args:
        grad: gradient tensor of any shape and rank (e.g. [64, 64, 3, 3]).
        num_gradients: number of elements to keep (k).

    Returns:
        A new tensor with the same shape, dtype, and device as ``grad``.
    """
    # Work on the flat view: topk indices into the flat tensor map 1:1 to
    # positions in the reshaped result, so no flat->multi-dim index
    # bookkeeping is needed.
    flat = grad.flatten()
    # Select by magnitude; fixed NameError (was `num_weights`, an undefined name).
    _, topk_indices = flat.abs().topk(num_gradients)
    # zeros_like keeps dtype/device of grad; the original
    # torch.cuda.FloatTensor(...) was CUDA-only and deprecated.
    out = torch.zeros_like(flat)
    out[topk_indices] = flat[topk_indices]
    return out.view_as(grad)
Thanks in advance for the help!