Support custom scalar types in PyTorch

This example illustrates creating a new tensor type using an existing supported scalar type, which in this case is ScalarType::ComplexFloat.

Should (or could) support for a new scalar type, say bfloat, int4, etc., be added through a C++ extension, or does it need to be added to the aten/c10 code in the PyTorch repo, where the existing scalar types are defined (ScalarType.h, etc.)?

In either case, is there a tutorial or checklist for this?

Some fundamental parts need to be added to aten/c10, similar to how we added ComplexFloat.

The rest can be done out-of-source.

If it’s bfloat16, we are interested in adding it into aten/c10 for some accelerators.
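
For context on what a bfloat16 scalar type has to represent: bfloat16 keeps the sign bit, the 8 exponent bits, and the top 7 mantissa bits of an IEEE float32, so it is effectively the upper 16 bits of the float32 bit pattern. Here is a minimal pure-Python sketch of that conversion. The helper names are made up for illustration, this is not PyTorch's implementation, and NaN inputs are not handled specially.

```python
import struct


def float_to_bfloat16_bits(x: float) -> int:
    """Convert x to float32, then to a 16-bit bfloat16 pattern.

    Uses round-to-nearest-even on the 16 low bits that get dropped.
    Assumption: NaN payloads are not treated specially in this sketch.
    """
    # Reinterpret the float32 encoding of x as an unsigned 32-bit integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Bias is 0x7FFF, plus 1 if the lowest kept bit is odd (ties go to even).
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bfloat16_bits_to_float(b: int) -> float:
    """Expand a 16-bit bfloat16 pattern back to a Python float."""
    # The bfloat16 bits become the upper half of a float32; the lower half is zero.
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]
```

Round-tripping a value through these two helpers shows how much precision is lost: only about 2 to 3 significant decimal digits survive, while the exponent range stays the same as float32.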


I see that there is already some boilerplate code pertaining to complex float types. Things like ScalarType, TypeID, deferred/lazy registration of the complex types, etc.
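
To check my understanding of those pieces, here is a toy Python model of a scalar type enum plus deferred (lazy) registration. It is purely illustrative: the class names, enum values, and metadata fields are invented for this sketch and do not mirror the actual definitions in aten/c10.

```python
from enum import Enum
from typing import Callable, Dict


class ScalarType(Enum):
    # Hypothetical values; the real enum in ScalarType.h differs.
    Float = 0
    ComplexFloat = 1
    BFloat16 = 2


class TypeRegistry:
    """Maps a ScalarType to type metadata that is built on first use."""

    def __init__(self) -> None:
        self._factories: Dict[ScalarType, Callable[[], dict]] = {}
        self._cache: Dict[ScalarType, dict] = {}

    def register(self, scalar_type: ScalarType, factory: Callable[[], dict]) -> None:
        # Registration only stores the factory; nothing is constructed yet.
        self._factories[scalar_type] = factory

    def get(self, scalar_type: ScalarType) -> dict:
        # Build the metadata the first time it is requested, then reuse it.
        if scalar_type not in self._cache:
            if scalar_type not in self._factories:
                raise KeyError(f"{scalar_type.name} is not registered")
            self._cache[scalar_type] = self._factories[scalar_type]()
        return self._cache[scalar_type]


registry = TypeRegistry()
# An out-of-source extension could register its type like this (hypothetical metadata).
registry.register(ScalarType.BFloat16, lambda: {"name": "bfloat16", "itemsize": 2})
```

The point of the deferred step is that the enum entry has to exist up front, while the metadata behind it can be supplied later and is only built when something asks for it.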

Yes, I was referring to bfloat16. Glow had an issue open about adding a bfloat type too, but, IIANM, it has been deferred.

I’ll explore more.