Why does a PyTorch CUDA tensor need two dispatch keys, CUDA and AutogradCUDA?

I wrote some test code:

import torch

device = torch.device("cuda")
x = torch.randn(2, 2).to(device)
y = torch.randn(2).to(device)

x.requires_grad_(True)  # let autograd track operations on x
z = torch.add(x, y)

z.sum().backward()
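
As a side note, the dispatch key set of a tensor can be printed from Python with the private helper torch._C._dispatch_keys (an internal API, so the exact name and output may differ between PyTorch versions):

import torch

x = torch.randn(2, 2, device="cuda")
# Private helper; on my build the printed set includes both CUDA and
# AutogradCUDA, e.g. DispatchKeySet(CUDA, ..., AutogradCUDA, ...)
print(torch._C._dispatch_keys(x))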

A tensor on CUDA has two dispatch keys, CUDA and AutogradCUDA. I read the PyTorch dispatcher source code and found that the dispatch key AutogradCUDA is excluded (masked out before redispatching) in VariableTypeEverything.cpp:

at::Tensor add_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  // ... autograd bookkeeping elided; self_/other_ are the unpacked inputs
  // (see the full listing later in this thread) ...
  auto _tmp = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    // Redispatch with the autograd keys masked out of the key set, so the
    // next dispatch lands on the backend (CUDA) kernel.
    return at::redispatch::add(ks & c10::after_autograd_keyset, self_, other_, alpha);
  })();
  return _tmp;
}

Finally, PyTorch dispatches to the wrapper_CUDA_add_Tensor function, which was registered earlier under the CUDA dispatch key. That function does the real computation and produces the result.
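
For what it's worth, the same per-backend registration pattern can be sketched from Python with the torch.library API. The "demo" namespace and "my_add" op below are made up for illustration; this only mirrors the idea of binding a kernel to the CUDA key, not PyTorch's actual generated registration:

import torch

# Toy sketch: bind a kernel to the CUDA dispatch key, mirroring the idea
# behind the generated registration of wrapper_CUDA_add_Tensor. The
# "demo" namespace and "my_add" op are hypothetical.
lib = torch.library.Library("demo", "DEF")
lib.define("my_add(Tensor a, Tensor b) -> Tensor")

def my_add_cuda(a, b):
    # the real computation, reached only when dispatch lands on CUDA
    return a + b

lib.impl("my_add", my_add_cuda, "CUDA")

a = torch.randn(2, device="cuda")
b = torch.randn(2, device="cuda")
print(torch.ops.demo.my_add(a, b))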

My question is: why do we need AutogradCUDA when creating a CUDA tensor? What does PyTorch do with the AutogradCUDA dispatch key? I guess there is some relation between the PyTorch dispatcher and the PyTorch autograd system, but I can't figure out the connection.

I would assume the Autograd dispatching is triggered by calling x.requires_grad_(True), which lets Autograd track operations on x, and not by the usage of the CUDA backend.
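
One way to check this assumption is to compare the key set before and after the call (again via the private torch._C._dispatch_keys helper, so this may differ between PyTorch versions):

import torch

x = torch.randn(2, 2, device="cuda")
print(torch._C._dispatch_keys(x))  # key set before grad tracking

x.requires_grad_(True)
print(torch._C._dispatch_keys(x))  # key set after grad tracking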

Thank you very much for your reply.

I printed the call chain; the key call sequence is as follows:

  1. torch::autograd::THPVariable_add at pytorch/torch/csrc/autograd/generated/python_torch_functions_2.cpp
  2. torch::autograd::VariableType::(anonymous namespace)::add_Tensor at pytorch/torch/csrc/autograd/generated/VariableType_2.cpp

The add_Tensor function in VariableType_2.cpp looks as follows:

at::Tensor add_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  auto& self_ = unpack(self, "self", 0);
  auto& other_ = unpack(other, "other", 1);
  // Decide whether any autograd bookkeeping is needed at all.
  auto _any_requires_grad = compute_requires_grad( self, other );

  (void)_any_requires_grad;
  auto _any_has_forward_grad_result = (isFwGradDefined(self) || isFwGradDefined(other));
  (void)_any_has_forward_grad_result;
  std::shared_ptr<AddBackward0> grad_fn;
  if (_any_requires_grad) {
    // Build the backward node and connect it to the graph nodes of the inputs.
    grad_fn = std::shared_ptr<AddBackward0>(new AddBackward0(), deleteNode);
    grad_fn->set_next_edges(collect_next_edges( self, other ));
    grad_fn->other_scalar_type = other.scalar_type();
    grad_fn->alpha = alpha;
    grad_fn->self_scalar_type = self.scalar_type();
  }

  // Redispatch below the autograd keys to reach the backend (CUDA) kernel.
  auto _tmp = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    return at::redispatch::add(ks & c10::after_autograd_keyset, self_, other_, alpha);
  })();
  auto result = std::move(_tmp);
  if (grad_fn) {
      // Attach grad_fn to the output so backward() can traverse the graph.
      set_history(flatten_tensor_args( result ), grad_fn);
  }
  return result;
}

And add_Tensor is registered under the Autograd dispatch key in pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:

TORCH_LIBRARY_IMPL(aten, Autograd, m) {
  m.impl("add.Tensor", TORCH_FN(VariableType::add_Tensor));
}

So my guess is that PyTorch first dispatches to the add_Tensor function according to the dispatch key AutogradCUDA. Inside add_Tensor, it builds the nodes and edges that the autograd computation graph requires.
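
By the way, that bookkeeping is visible from Python. Running the test code from my first post, the result holds the AddBackward0 node built in add_Tensor, and its next_functions correspond to the edges collected by collect_next_edges:

import torch

x = torch.randn(2, 2, device="cuda", requires_grad=True)
y = torch.randn(2, device="cuda")
z = torch.add(x, y)

# grad_fn is the AddBackward0 node that add_Tensor attached via set_history
print(z.grad_fn)                 # e.g. <AddBackward0 object at 0x...>
# next_functions holds the next edges: an AccumulateGrad edge for the
# leaf x, and no edge for y since it does not require grad
print(z.grad_fn.next_functions)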

But I still haven't figured out how PyTorch dispatches a tensor whose dispatch key set contains AutogradCUDA to a function registered for the Autograd key. After all, add_Tensor is registered to the dispatch key Autograd, yet a tensor whose key set contains AutogradCUDA is still dispatched to the add_Tensor function.
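
In case it helps anyone looking at the same thing: the kernels registered for an operator can be dumped with another private helper, torch._C._dispatch_dump (an internal API, so availability and output format may vary across versions), which should show what the Autograd and AutogradCUDA entries resolve to for add.Tensor:

import torch

# Private helper: dumps the kernels registered for aten::add.Tensor,
# one entry per dispatch key (exact output depends on the PyTorch version).
print(torch._C._dispatch_dump("aten::add.Tensor"))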