How is topk selection implemented in torch?

I’m using torch.topk and I’d like to understand how top-k selection is implemented in a differentiable way (with respect to the top-k values).

Any information (even just a description in words or pseudocode) is most welcome.

During the forward pass, topk saves the indices of the selected top-k values. During the backward pass, the incoming gradient is routed back through those indices only; every other position of the input receives a zero gradient.
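
You can see this routing directly from Python. In this small sketch the input values and k are made up for illustration; the gradient of the sum of the top-k values ends up as 1 at the selected positions and 0 everywhere else:

import torch

x = torch.tensor([0.3, 1.5, -0.2, 0.9], requires_grad=True)  # made-up example input
values, indices = torch.topk(x, k=2)   # selects 1.5 and 0.9
values.sum().backward()
print(indices)  # tensor([1, 3])
print(x.grad)   # tensor([0., 1., 0., 1.]) -- gradient only at the selected indices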

Code for forward:

std::tuple<at::Tensor,at::Tensor> topk(c10::DispatchKeySet ks, const at::Tensor & self, int64_t k, int64_t dim, bool largest, bool sorted) {
  auto& self_ = unpack(self, "self", 0);
  auto _any_requires_grad = compute_requires_grad( self );
  (void)_any_requires_grad;
  std::shared_ptr<TopkBackward> grad_fn;
  if (_any_requires_grad) {
    // Create the backward node and record what it needs: the input sizes
    // and the dimension along which topk was taken.
    grad_fn = std::shared_ptr<TopkBackward>(new TopkBackward(), deleteNode);
    grad_fn->set_next_edges(collect_next_edges( self ));
    grad_fn->self_sizes = self.sizes().vec();
    grad_fn->dim = dim;
  }
  at::Tensor values;
  at::Tensor indices;
  auto _tmp = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    // Redispatch to the actual (non-autograd) topk kernel.
    return at::redispatch::topk(ks & c10::after_autograd_keyset, self_, k, dim, largest, sorted);
  })();
  std::tie(values, indices) = std::move(_tmp);
  if (grad_fn) {
      // Attach the backward node to the differentiable output (values only).
      set_history(flatten_tensor_args( values ), grad_fn);
  }
  throw_error_for_complex_autograd(values, "topk");
  TORCH_CHECK(!(isFwGradDefined(self)), "Trying to use forward AD with topk that does not support it.");
  if (grad_fn) {
    // Save the indices of the top-k elements for use in the backward pass.
    grad_fn->indices_ = SavedVariable(indices, true);
  }
  return std::make_tuple(std::move(values), std::move(indices));
}
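
The effect of set_history and the saved indices is visible from Python: only the values output gets the TopkBackward node attached as its grad_fn, while the integer indices output is not differentiable at all (the exact grad_fn class name may differ between PyTorch versions):

import torch

x = torch.randn(4, requires_grad=True)
values, indices = torch.topk(x, k=2)
print(values.grad_fn)         # e.g. <TopkBackward0 object at 0x...>
print(indices.requires_grad)  # False: integer indices carry no gradient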

Code for backward:

variable_list TopkBackward::apply(variable_list&& grads) {
  std::lock_guard<std::mutex> lock(mutex_);

  IndexRangeGenerator gen;
  auto self_ix = gen.range(1);
  variable_list grad_inputs(gen.size());
  auto& grad = grads[0];
  // Recover the indices that were saved during the forward pass.
  auto indices = indices_.unpack(shared_from_this());
  bool any_grad_defined = any_variable_defined(grads);
  if (should_compute_output({ self_ix })) {
    // Scatter the incoming gradient back to the saved indices; all other
    // positions of the input-shaped gradient stay zero.
    auto grad_result = any_grad_defined ? (value_selecting_reduction_backward(grad, dim, indices, self_sizes, true)) : Tensor();
    copy_range(grad_inputs, self_ix, grad_result);
  }
  return grad_inputs;
}
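
In effect, the backward just scatters the incoming gradient into a zero tensor of the input’s shape at the saved indices (that is what value_selecting_reduction_backward does for topk). As an illustrative sketch only, not the real implementation, the same behaviour can be reproduced with a custom autograd.Function; the class name MyTopK below is made up:

import torch

class MyTopK(torch.autograd.Function):
    """Illustrative re-implementation of topk's autograd behaviour (a sketch)."""

    @staticmethod
    def forward(ctx, x, k, dim):
        values, indices = torch.topk(x, k, dim=dim)
        # Save what the backward needs: the selected indices plus shape/dim info.
        ctx.save_for_backward(indices)
        ctx.dim = dim
        ctx.input_shape = x.shape
        return values, indices

    @staticmethod
    def backward(ctx, grad_values, grad_indices):
        (indices,) = ctx.saved_tensors
        # Route the gradient of the values back to the positions they came from;
        # everything else stays zero.
        grad_x = torch.zeros(ctx.input_shape, dtype=grad_values.dtype,
                             device=grad_values.device)
        grad_x.scatter_(ctx.dim, indices, grad_values)
        return grad_x, None, None


x = torch.randn(3, 5, requires_grad=True)
values, indices = MyTopK.apply(x, 2, 1)
values.sum().backward()
print(x.grad)  # nonzero only at the top-2 positions of each row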