How can I find the return dtype of every aten function? I’m working on a delayed tensor, and I need the dtype to initialize my TensorImpl. But without executing the operation, I don’t have a way of knowing its dtype.
I’ve seen the file torch/share/ATen/Declarations.yaml, but it only gives the return type as Tensor/Scalar, not the dtype.
The dtype of an op’s output mostly matches that of its input tensor arguments, except for some ops that return booleans (e.g., eq/gt/ne, is*). Some ops also upcast, e.g. sum(bool) → long.
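To make the rules above concrete, here is a small demo of the three cases (dtype follows the inputs via promotion, comparisons return bool, and sum over bool upcasts to long); `torch.result_type` exposes the promotion rule without running the op:

```python
import torch

# Most ops keep/promote the dtype of their inputs:
a = torch.ones(1, dtype=torch.int32)
b = torch.ones(1, dtype=torch.float32)
print((a + b).dtype)            # torch.float32 (int32 promotes to float32)
print(torch.result_type(a, b))  # same promotion rule, without running the op

# Comparison ops return bool regardless of input dtype:
print(torch.eq(a, a).dtype)     # torch.bool

# Some ops upcast: sum over a bool tensor yields long
print(torch.ones(3, dtype=torch.bool).sum().dtype)  # torch.int64
```

Note that `torch.result_type` only covers the promotion case, not the bool-returning or upcasting exceptions, which is exactly why a per-op table is needed.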
Is this information specified in a machine-readable way somewhere?
Worst case I was thinking of maybe writing a program that executes every single op with a singleton tensor of each of the dtypes and records the result. Could that work?
I’m not sure if there’s a better option out there, but structured kernels (and specifically meta tensors) might help with this a bit. If you construct meta tensors and pass them to an aten op, the returned tensor will have the correct size/dtype, but without any actual data allocation / computation being performed.
a = torch.ones(2, device='meta', dtype=torch.int32)
b = torch.ones(2, device='meta', dtype=torch.int64)
c = a + b # returns a meta tensor with correct output size / dtype, no actual computation is performed.
c.dtype # torch.int64
The wrinkle is that only operators that have been ported to structured kernels support meta tensors, so if your goal is to cover 100% of ops, this won’t work in the near term. There’s an open issue with a running list of supported/unsupported ops here.
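Given that coverage is partial, one way to check whether a particular op works with meta tensors is to just try it. A minimal sketch (the `supports_meta` helper is hypothetical; ops not yet ported typically raise NotImplementedError or RuntimeError for the Meta backend):

```python
import torch

def supports_meta(op, *example_args):
    """Probe whether an op can run on meta tensors (hypothetical helper)."""
    try:
        # Move tensor arguments to the meta device; no data is allocated.
        meta_args = [a.to('meta') if isinstance(a, torch.Tensor) else a
                     for a in example_args]
        op(*meta_args)
        return True
    except (NotImplementedError, RuntimeError):
        return False

x = torch.ones(2)
print(supports_meta(torch.add, x, x))  # add is structured, so True
```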
Thanks for pointing out this meta device. I was not aware of it. Very interesting indeed!
Though I guess for this specific question it’s maybe not super useful, since the cost of running a kernel on a singleton tensor is fine. And I only need to run this once in a while, so it’s not terrible even if it takes half an hour to go through all the ops.
I was really hoping the typing rules were available somewhere… Since there are all these yaml files with a lot of info, I was hopeful they’d be there.
I went ahead and created a little program that calls all ATen functions and produces a list with the typing rule used by each function.
Some interesting statistics:
ALL Bool: 37
ALL Byte: 4
ALL Char: 1
ALL Double: 3
ALL Float: 2
ALL Half: 1
ALL Int: 1
ALL Long: 17
ALL Short: 1
Num Types: 33
These are the names I gave to each typing rule. It’s interesting that a few rules are used only once; this may point to bugs, as there should be no reason for such exceptions.
I’m posting this here because it may be interesting for you guys. For example, it would be possible to generate the code for Meta tensors automatically from this data. AFAIU, Meta tensors only compute dtype + shape information. Here I have the dtype, and I’m working on shape inference next.
Another use is testing. Some of the calls crash with certain inputs, for example, so you could use this driver for fuzzing. It might also be helpful to compare the returned dtypes across different devices.
Anyway, I’ll leave that to you. I need this data for other purposes.
How it works:
- I use a script to generate a file with calls to every ATen function using default parameters. This script also generates a ninja file so I can run things in parallel.
- Then I wrote a driver that goes through each function and calls it with different tensors (one call per combination of dtypes)
- I wrote the C++ code with the typing rules by hand
- After running ninja, the result is a types.txt file indicating the typing rule per ATen function
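The actual driver is C++ generated via ninja, but the core loop can be sketched as a Python analogue: call each op with singleton tensors of every dtype combination and record the output dtype (the op list and `rules` table here are illustrative, not the real generated code):

```python
import itertools
import torch

# Singleton tensors of every dtype, one call per dtype combination.
DTYPES = [torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32,
          torch.int64, torch.float16, torch.float32, torch.float64]
OPS = {'add': torch.add, 'mul': torch.mul, 'eq': torch.eq}

rules = {}
for name, op in OPS.items():
    for dt1, dt2 in itertools.product(DTYPES, repeat=2):
        try:
            out = op(torch.ones(1, dtype=dt1), torch.ones(1, dtype=dt2))
            rules[(name, dt1, dt2)] = out.dtype
        except RuntimeError:
            pass  # some dtype combinations are unsupported; skip them

print(rules[('eq', torch.int32, torch.int64)])  # torch.bool
```

Classifying each op’s `(input dtypes → output dtype)` table against known patterns then yields rule names like “ALL Bool” above.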
cc @ailzhang @ezyang