Are any operators supported on fp8 tensors?

Since torch now supports the float8_e5m2 and float8_e4m3fn data types, we can convert a higher-precision tensor (e.g. bfloat16) to fp8 like:

a = torch.randn(8, 16, dtype=torch.bfloat16, device='cuda')
a_f8 = a.to(torch.float8_e5m2)

but it seems like torch doesn’t provide many operators for float8 tensors. For example, we cannot run

mask = a_f8 < 0.5
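
One possible workaround (a minimal sketch, assuming only the elementwise kernel is missing) is to upcast back to a higher precision first:

mask = a_f8.to(torch.bfloat16) < 0.5  # comparison kernels aren’t implemented for float8, so cast back up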

so do we have to rely on torchao for tensors of this dtype?

Yes, use torchao for native FP8 runs (or TransformerEngine as an alternative).
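
As a rough sketch of the torchao path (assuming a recent torchao release that ships the float8 training API, and an fp8-capable GPU such as an H100), the usual entry point is swapping the nn.Linear layers:

import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training  # assumes a recent torchao

# Toy model; fp8 matmuls want dims divisible by 16 and an fp8-capable GPU.
m = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 16)).to(device='cuda', dtype=torch.bfloat16)

# Swap nn.Linear for Float8Linear; the dynamic scaling is handled internally.
convert_to_float8_training(m)

x = torch.randn(32, 64, device='cuda', dtype=torch.bfloat16)
m(x).sum().backward()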
CC @marksaroufim for visibility

Thanks for replying!
BTW, does PyTorch have any plans to support such operations natively? Using torchao feels like kind of a hacky way to do so.

I don’t know if the current derived dtype implementations will land in PyTorch directly. Could you describe what’s “hacky” about using torchao?


@cokespace2 mind sharing a bit more? We can certainly add more stuff in core, but I just want to better understand what’s hacky. Is it the specific implementation in AO, or is it any out-of-core implementation?

Well, for me, I lean toward using native PyTorch so that I have fewer library dependencies. The float8 dtype, as far as I understand, is more like float16 than qint8; the latter carries attributes like qscheme and qscale.

For example, I believe that multiplying two qint8 tensors needs a dedicated quantized operator (just as addition needs qadd), since we have to take the qscales of the input tensors into account, and the output comes with a new qscale of its own.
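
Roughly, the qint8 side looks like this (eager-mode quantized ops on CPU; the output scale and zero_point have to be supplied explicitly):

import torch

qa = torch.quantize_per_tensor(torch.randn(4), scale=0.1, zero_point=0, dtype=torch.qint8)
qb = torch.quantize_per_tensor(torch.randn(4), scale=0.2, zero_point=0, dtype=torch.qint8)

# A dedicated quantized op is needed, and it must be told the output scale/zero_point.
qc = torch.ops.quantized.mul(qa, qb, 0.3, 0)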

However, multiplying two float8 tensors could work the same way as multiplying two float16 tensors, so we could call torch.matmul directly.
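
For context, the closest thing in core today seems to be the low-level torch._scaled_mm primitive, which does the fp8 matmul but still wants explicit scale factors (a rough sketch; it’s a private API whose signature has changed across releases, and it needs an fp8-capable GPU such as an H100):

import torch

x = torch.randn(16, 32, device='cuda')
w = torch.randn(64, 32, device='cuda')        # laid out like a Linear weight

x_f8 = x.to(torch.float8_e4m3fn)
w_f8 = w.to(torch.float8_e4m3fn)

scale_x = torch.tensor(1.0, device='cuda')    # per-tensor scales (float32), here just 1.0
scale_w = torch.tensor(1.0, device='cuda')

# The second operand must be column-major, hence the .t() on the row-major weight.
y = torch._scaled_mm(x_f8, w_f8.t(), scale_a=scale_x, scale_b=scale_w, out_dtype=torch.bfloat16)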

Currently I have to import torchao and build a Float8Linear module just to get a float8 matmul operator, not to mention that some other operations are not supported yet.

Maybe the word hacky is inaccurate, but I just thought we could have some basic operations supported in native PyTorch, so we would have more flexibility.

cc @ptrblck

No, that’s not the case, as the scaling is needed for proper training.
Take a look at e.g. this TransformerEngine doc explaining the workflow.
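
For reference, a minimal sketch of that TransformerEngine workflow (assuming transformer_engine is installed and an fp8-capable GPU; the recipe names follow the TE docs):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

model = te.Linear(128, 128, bias=True).cuda()
inp = torch.randn(32, 128, device='cuda')

# fp8_autocast owns the scaling factors; the user never touches them directly.
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)
out.sum().backward()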

I know that during training with float8 we should scale our float8 tensors; by doing so we keep the model training stable.

But what I’d like to discuss is that the scaling and the computation can be separated. When we train a model with float16 we use loss scaling too, yet we can still call many float16 functions in native PyTorch.
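
To make the analogy concrete, this is the fp16 pattern I mean: the loss scaling lives in GradScaler, while the forward math is just ordinary autocast ops:

import torch

model = torch.nn.Linear(32, 32).cuda()
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(8, 32, device='cuda')
with torch.autocast('cuda', dtype=torch.float16):
    loss = model(x).sum()      # compute in fp16; no scale attached to the tensors

scaler.scale(loss).backward()  # the scaling is handled separately from the math
scaler.step(opt)
scaler.update()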

In short, I think a float8 tensor could stand on its own without any scaling factor and could be used just the way we treat float16 tensors.