How to Use a Custom fp16 ↔ fp8 Datatype Represented in uint8 in PyTorch

Hi everyone,

I’ve been working on a project where I wrote a custom datatype conversion from fp16 to fp8 (E4M3 and E5M2), with the fp8 bit patterns stored in uint8 tensors in PyTorch. I’m a bit stuck on how to use this datatype properly in my PyTorch code. I want to add this new data type to PyTorch so I can use it for mixed-precision training later.

Could someone guide me on the best practices for implementing and utilizing this custom datatype in PyTorch, considering that it is represented in uint8?

torchao’s Float8Tensor defined here could be a good starting point for your custom tensor class.