Background
Hi, I’ve been diving into the source code of torch.fft.fft2, but I’ve run into a problem.
What I have done:
- Step 1: I called torch.fft.fft2(img) and found its source in torch/csrc/autograd/generated/python_fft_functions.cpp.
- Step 2: In torch/csrc/autograd/generated/python_fft_functions.cpp, the following code calls fft_fft2, which is located in aten/src/ATen/native/SpectralOps.cpp:

    auto dispatch_fft_fft2 = [](const at::Tensor & self,
                                at::OptionalIntArrayRef s,
                                at::IntArrayRef dim,
                                c10::optional<c10::string_view> norm) -> at::Tensor {
      pybind11::gil_scoped_release no_gil;  // release the GIL while the op runs
      std::cout << "self.sizes()=" << self.sizes() << std::endl;  // debug print
      return at::fft_fft2(self, s, dim, norm);
    };
- Step 3: In aten/src/ATen/native/SpectralOps.cpp, the following code calls _fft_c2c, which is implemented in torch/csrc/autograd/generated/VariableType_1.cpp:

    Tensor fft_c2c_maybe_out(
        c10::string_view fname, const Tensor& out, const Tensor& input,
        IntArrayRef dim, int64_t norm, bool forward) {
      if (out.defined()) {
        // out= variant: the caller supplied an output tensor
        TORCH_CHECK(out.is_complex(), fname,
                    " expects a complex output tensor, but got ", out.scalar_type());
        auto out_mut = out;
        return at::_fft_c2c_outf(input, dim, norm, forward, out_mut);
      }
      // functional variant: allocate a new output
      return at::_fft_c2c(input, dim, norm, forward);
    }
- Step 4: In torch/csrc/autograd/generated/VariableType_1.cpp, the following code calls _fft_c2c_symint:

    auto _tmp = ([&]() {
      at::AutoDispatchBelowADInplaceOrView guard;
      // mask the key set with after_autograd_keyset, then redispatch
      return at::redispatch::_fft_c2c_symint(
          ks & c10::after_autograd_keyset, self_, dim, normalization, forward);
    })();
After that, I couldn’t find where at::redispatch::_fft_c2c_symint is implemented.
Question
My first question: where is the core logic of torch.fft.fft2 implemented?
If that can’t be answered directly, my second question: where is the redispatch function for _fft_c2c_symint defined?
Here’s my third question.
There seems to be a singleton dispatcher registry in aten/src/ATen/core/dispatch/Dispatcher.h, which registers and looks up the redispatch function through a mapping:
    const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet);
The retrieved redispatch function is a KernelFunction. Given that KernelFunction, how can I get the name of the function it wraps?