I’m using `torch.topk` and would like to understand how top-k selection is implemented so that it is differentiable (with respect to the returned top-k values).
Any information is welcome, even just a description in words or pseudocode.
I’m using torch.topk and I would like to find information about how top k selection is implemented in a differentiable way (with respect to the top k values).
Any information (also just a description in words or pseudocode) is most welcome.
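During the forward pass, `topk` saves the indices of the selected top-k elements. During the backward pass, it routes the incoming gradient back through those indices only. Each selected input element receives the gradient of its corresponding output value, and every other input element receives zero. The choice of indices itself is treated as a constant: gradients flow through the selected values, not through the selection.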
Code for the forward pass:
// Autograd-layer kernel for topk.
// It runs the real topk kernel below the autograd layer (see the redispatch
// below). When `self` requires grad, it also records a TopkBackward node that
// remembers which input positions were selected, so the backward pass can route
// gradients to exactly those positions.
//
// Parameters mirror torch.topk: `k` elements are taken along dimension `dim`;
// `largest` picks the largest (true) or smallest (false) elements; `sorted`
// controls whether the returned values are sorted. `ks` is the dispatch key set
// used to redispatch to the kernels that run after autograd.
// Returns (values, indices). Only `values` is attached to the autograd graph.
std::tuple<at::Tensor,at::Tensor> topk(c10::DispatchKeySet ks, const at::Tensor & self, int64_t k, int64_t dim, bool largest, bool sorted) {
// Underlying tensor for `self`; name/position are presumably used only for error messages.
auto& self_ = unpack(self, "self", 0);
// Decides whether a backward node must be built at all.
auto _any_requires_grad = compute_requires_grad( self );
(void)_any_requires_grad;
std::shared_ptr<TopkBackward> grad_fn;
if (_any_requires_grad) {
// Create the backward node and connect it to the gradient edge of `self`.
grad_fn = std::shared_ptr<TopkBackward>(new TopkBackward(), deleteNode);
grad_fn->set_next_edges(collect_next_edges( self ));
// Backward needs the input shape and the selection dim to rebuild a
// full-size gradient for `self` from the (smaller) gradient of `values`.
grad_fn->self_sizes = self.sizes().vec();
grad_fn->dim = dim;
}
at::Tensor values;
at::Tensor indices;
// Compute the actual top-k with autograd handling disabled for this call, by
// redispatching with only the keys that come after autograd.
auto _tmp = ([&]() {
at::AutoDispatchBelowADInplaceOrView guard;
return at::redispatch::topk(ks & c10::after_autograd_keyset, self_, k, dim, largest, sorted);
})();
std::tie(values, indices) = std::move(_tmp);
if (grad_fn) {
// Only `values` gets a grad_fn; `indices` is an integer index tensor and is
// deliberately not attached to the graph (it carries no gradient).
set_history(flatten_tensor_args( values ), grad_fn);
}
// NOTE(review): presumably rejects complex `values` that require grad, since
// topk's backward is not defined for complex inputs — confirm.
throw_error_for_complex_autograd(values, "topk");
// Forward-mode AD is not implemented for this op.
TORCH_CHECK(!(isFwGradDefined(self)), "Trying to use forward AD with topk that does not support it.");
if (grad_fn) {
// Save the selected positions for backward. The `true` flag presumably marks
// `indices` as an output of this node — confirm against SavedVariable.
grad_fn->indices_ = SavedVariable(indices, true);
}
return std::make_tuple(std::move(values), std::move(indices));
}
Code for the backward pass:
// Backward function for topk.
//
// grads[0] is the gradient flowing into the `values` output. The indices saved
// during the forward pass say where each value came from in `self`.
// value_selecting_reduction_backward turns (grad, indices) into a gradient of
// shape `self_sizes`; presumably it scatters `grad` into zeros at `indices`
// along `dim` (TODO confirm against its definition).
//
// Returns one gradient slot, for `self`. It holds an undefined Tensor when no
// incoming gradient is defined.
variable_list TopkBackward::apply(variable_list&& grads) {
  std::lock_guard<std::mutex> node_guard(mutex_);

  IndexRangeGenerator range_gen;
  auto self_range = range_gen.range(1);
  variable_list result(range_gen.size());

  auto& values_grad = grads[0];
  // Unpack before checking which outputs are needed, matching the original
  // order (unpacking reports errors such as already-freed saved tensors).
  auto saved_indices = indices_.unpack(shared_from_this());
  bool has_incoming_grad = any_variable_defined(grads);

  if (!should_compute_output({ self_range })) {
    return result;
  }

  Tensor self_grad;
  if (has_incoming_grad) {
    self_grad = value_selecting_reduction_backward(values_grad, dim, saved_indices, self_sizes, true);
  }
  copy_range(result, self_range, self_grad);
  return result;
}