Since torch now supports the fp8_e5m2 and fp8_e4m3fn data types, we can convert a fp16/bf16 tensor to fp8 like:
a = torch.randn(8, 16, dtype=torch.bfloat16, device='cuda')
a_f8 = a.to(torch.float8_e5m2)
but it seems that torch doesn’t provide any operators on float8 tensors; for example, we cannot do
mask = a_f8 < 0.5
so do we have to rely on torchao for tensors of this dtype?
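For reference, here is a minimal repro sketch of what I mean (the exact error text may differ across PyTorch builds); the only workaround I found is to upcast back to bf16 first, which defeats the purpose of fp8:

import torch

a = torch.randn(8, 16, dtype=torch.bfloat16, device='cuda')
a_f8 = a.to(torch.float8_e5m2)

# Elementwise ops are not implemented for float8, e.g. the following
#   mask = a_f8 < 0.5
# fails with something like: RuntimeError: "lt_cuda" not implemented for 'Float8_e5m2'

# Workaround: upcast before the comparison (loses the fp8 memory benefit)
mask = a_f8.to(torch.bfloat16) < 0.5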
Yes, use torchao for native FP8 runs (or TransformerEngine as an alternative).
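A rough sketch of the torchao flow (API names may differ slightly across torchao versions, and real fp8 gemms need a GPU with fp8 support, e.g. H100):

import torch
from torchao.float8 import convert_to_float8_training

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024, bias=False),
    torch.nn.Linear(1024, 1024, bias=False),
).to('cuda', torch.bfloat16)

# Swaps eligible nn.Linear modules for Float8Linear; the fp8 casts and
# scaling are then handled inside the module.
convert_to_float8_training(model)

x = torch.randn(16, 1024, device='cuda', dtype=torch.bfloat16)
out = model(x)
out.sum().backward()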
CC @marksaroufim for visibility
Thanks for replying!
BTW, does PyTorch have plans to support such operations natively? Using torchao feels like a kind of hacky way to do it.
I don’t know if the current derived dtype implementations will land in PyTorch directly. Could you describe what’s “hacky” about using torchao?
@cokespace2 mind sharing a bit more? We can certainly add more stuff in core, but I just want to better understand what’s hacky. Is it the specific implementation in AO, or is it any out-of-core implementation?
Well, I lean toward using native PyTorch so that I have fewer library dependencies. The float8 dtype, as far as I understand, is more like float16 than qint8, while the latter carries attributes like qscheme and qscale.
For example, I believe that multiplying two qint8 tensors needs a dedicated quantized operator, since we have to compute with the qscales of the input tensors, and the output also comes with a new qscale.
However, multiplying two float8 tensors could work the same way as multiplying two float16 tensors, so we could use torch.matmul directly.
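To make the distinction I mean concrete, a quick sketch (q_scale/q_zero_point only exist on quantized tensors):

import torch

x = torch.randn(4, 4)

# qint8 tensors carry quantization metadata, so binary ops need dedicated
# quantized kernels that also derive a scale for the output.
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.qint8)
print(qx.qscheme(), qx.q_scale(), qx.q_zero_point())

# float8 tensors carry no such metadata: they are plain floating-point
# storage, just narrower than float16/bfloat16.
f8 = x.to(torch.float8_e5m2)
print(f8.dtype, f8.element_size())  # 1 byte per element, no scale attribute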
Currently I have to import torchao and build a Float8Linear module to get a float8 matmul operator, not to mention that some other operations are not supported yet.
Maybe the word hacky is inaccurate, but I just think that having some basic operations supported in native PyTorch would give us more flexibility.
cc @ptrblck
No, that’s not the case, as the scaling is needed for proper training.
Take a look at e.g. this TransformerEngine doc explaining the workflow.
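Even the fp8 GEMM primitive that does live in core, the private torch._scaled_mm, takes explicit scale tensors, which illustrates the point. A rough sketch (the op is private, its signature and return type have changed between releases, and it needs an fp8-capable GPU such as Hopper/Ada):

import torch

# (M, K) @ (K, N) with all dims multiples of 16, as the fp8 gemm requires
a = torch.randn(32, 64, device='cuda', dtype=torch.bfloat16)
b = torch.randn(16, 64, device='cuda', dtype=torch.bfloat16)

a_f8 = a.to(torch.float8_e4m3fn)        # row-major
b_f8 = b.to(torch.float8_e4m3fn).t()    # column-major, as cuBLAS expects

scale = torch.tensor(1.0, device='cuda')  # trivial per-tensor scales for illustration
out = torch._scaled_mm(a_f8, b_f8, scale_a=scale, scale_b=scale,
                       out_dtype=torch.bfloat16)
# Note: some older releases return an (out, amax) tuple instead of a tensor.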
I know that during training with float8 we should scale our float8 tensors; by doing so we keep the model training stable.
But what I’d like to discuss is that the scaling and the computing can be separated. When we train a model with float16 we use loss scaling too, yet we can still call many float16 functions in native PyTorch.
In short, I think a float8 tensor could stand on its own without any scaling factor and be used just like we treat float16 tensors.
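To make the fp16 analogy concrete, a minimal sketch of what I mean (newer PyTorch versions spell the scaler torch.amp.GradScaler('cuda')): the loss scaling is handled separately from the half-precision ops, which all go through the normal PyTorch API.

import torch

model = torch.nn.Linear(16, 16).cuda()
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(8, 16, device='cuda')
with torch.autocast('cuda', dtype=torch.float16):
    loss = model(x).square().mean()   # fp16 ops via the regular PyTorch API
scaler.scale(loss).backward()         # loss scaling handled separately from the ops
scaler.step(opt)
scaler.update()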