I’m looking for a function (ideally a constexpr one) that maps a type `T` to a `ScalarType`. I imagine it’d be called `at::ScalarType<T>()` or the like, but grepping through the PyTorch codebase I can’t find anything promising.
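Concretely, I want to write something like:

```cpp
#include <ATen/ATen.h>

// at::ScalarType<T>() is the (hypothetical) type -> ScalarType mapping I'm looking for
template <typename T>
at::Tensor empty() {
  return at::empty({1}, at::dtype(at::ScalarType<T>()));
}
```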
so that a call like `empty<float>()` gives me back an `at::kFloat` tensor.
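To make it concrete, here’s a minimal, self-contained sketch of the kind of compile-time mapping I mean. It deliberately avoids PyTorch: `Dtype` is a hypothetical local enum standing in for `at::ScalarType`, and `DtypeOf`, `dtype_of`, `TensorStub` and this `empty` are hypothetical names for illustration, not ATen API:

```cpp
#include <cstdint>

// Hypothetical stand-in for at::ScalarType, with only a few entries.
enum class Dtype { Float, Double, Int, Long, Bool };

// Primary template is declared but never defined, so asking for an
// unmapped type is a compile error rather than a silent default.
template <typename T>
struct DtypeOf;

// One explicit specialization per supported C++ type.
template <> struct DtypeOf<float>        { static constexpr Dtype value = Dtype::Float; };
template <> struct DtypeOf<double>       { static constexpr Dtype value = Dtype::Double; };
template <> struct DtypeOf<std::int32_t> { static constexpr Dtype value = Dtype::Int; };
template <> struct DtypeOf<std::int64_t> { static constexpr Dtype value = Dtype::Long; };
template <> struct DtypeOf<bool>         { static constexpr Dtype value = Dtype::Bool; };

// Function-style wrapper, shaped like the at::ScalarType<T>() call I imagined.
template <typename T>
constexpr Dtype dtype_of() {
  return DtypeOf<T>::value;
}

// Hypothetical tensor stub that only records its dtype, to mirror empty<T>().
struct TensorStub {
  Dtype dtype;
};

template <typename T>
constexpr TensorStub empty() {
  return TensorStub{dtype_of<T>()};
}

// Everything resolves at compile time.
static_assert(dtype_of<float>() == Dtype::Float, "float maps to Float");
static_assert(empty<std::int64_t>().dtype == Dtype::Long, "int64_t maps to Long");
```

The downside of hand-rolling this is keeping the list of specializations in sync with the real `at::ScalarType`, which is why I’d rather use something that already ships with ATen.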
I understand this isn’t the usual direction of conversion, and that control flow should generally go from generic to specific (a runtime `ScalarType` selecting a compile-time `T`). I know I could write this as `empty(at::ScalarType)` and then use the DISPATCH macros to do the templating. For the particular chunk of code I’m working on, though, the type-to-`ScalarType` direction would be useful, and it feels like it should exist; I just can’t find it.
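For comparison, the generic -> specific direction (the one the DISPATCH macros cover) boils down to a switch over the runtime enum that instantiates a template once per case. This is a hedged, self-contained sketch with hypothetical names (`Dtype`, `TypeTag`, `dispatch`, `ElementSize`); it illustrates the idea only and is not how the DISPATCH macros are actually implemented:

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>

// Hypothetical stand-in for at::ScalarType, with only a few entries.
enum class Dtype { Float, Double, Long };

// Empty tag that carries a C++ type as a value, so an ordinary function can receive it.
template <typename T>
struct TypeTag {
  using type = T;
};

// Runtime enum -> compile-time type: each case instantiates f's templated
// operator() with the matching C++ type. Unknown values throw.
template <typename F>
constexpr auto dispatch(Dtype d, F f) {
  switch (d) {
    case Dtype::Float:  return f(TypeTag<float>{});
    case Dtype::Double: return f(TypeTag<double>{});
    case Dtype::Long:   return f(TypeTag<std::int64_t>{});
  }
  throw std::invalid_argument("unsupported dtype");
}

// Example "kernel", written once as a template and selected via dispatch.
struct ElementSize {
  template <typename T>
  constexpr std::size_t operator()(TypeTag<T>) const {
    return sizeof(T);
  }
};

static_assert(dispatch(Dtype::Double, ElementSize{}) == sizeof(double), "Double -> double");
```

Every branch must return the same type, which is one reason this style pushes you to write the per-type code as a single template.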
(More generally, if anyone has advice on searching for a specific signature in a large codebase, that would be great too. The best tool I’ve found is CppDepend, which is $$$.)