FP8 support on H100

With H100 supporting FP8, is there any plan to support FP8 training in PyTorch?
Right now it looks like the only alternative is to use FP8 through NVIDIA Transformer Engine, and even Transformer Engine doesn't support FP8 for all nn modules (e.g. conv layers).

@drisspg has been working on this

@marksaroufim is there a timeline for the feature to be available?

Hey @navmarri, we have actually already landed some FP8 primitives in PyTorch. Specifically, we have landed the fp8e4m3 and fp8e5m2 dtypes and a private function (with no BC/FC guarantees) for doing scaled matmul on H100 machines.

For a quick example script of how you can use it, check out:
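Roughly, the idea looks like this (a minimal, unofficial sketch, assuming an H100 and a recent build; `torch._scaled_mm` is private, so its exact signature and return value may differ between releases):

```python
import torch

device = "cuda"  # requires an H100 (sm90) for the FP8 matmul kernels

# Inputs in higher precision, shaped like an nn.Linear: x @ w.t()
x = torch.randn(16, 32, device=device)
w = torch.randn(64, 32, device=device)  # (out_features, in_features)

# Cast to FP8: e4m3 is typically used for activations/weights,
# e5m2 for gradients (not shown here).
x_fp8 = x.to(torch.float8_e4m3fn)
w_fp8 = w.to(torch.float8_e4m3fn)

# Per-tensor scales; real training code derives these from amax statistics.
scale_x = torch.tensor(1.0, device=device)
scale_w = torch.tensor(1.0, device=device)

# Private op, no BC/FC guarantees: the second operand is expected in
# column-major layout, hence the transpose. Some older builds returned an
# (output, amax) tuple instead of a single tensor.
out = torch._scaled_mm(
    x_fp8,
    w_fp8.t(),
    scale_a=scale_x,
    scale_b=scale_w,
    out_dtype=torch.bfloat16,
)
print(out.shape, out.dtype)  # torch.Size([16, 64]) torch.bfloat16
```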

@drisspg awesome!
Does this work with FP8-precision training using FSDP on H100s?

As far as I know (and as also mentioned in NVIDIA/TransformerEngine),
the RTX 4000 series (compute capability 8.9) should also support FP8 compute.
Does PyTorch support that too,
or is it H100-only for now, with further work needed for sm89 cards?

Is mixed-precision support planned for FP8?

There was an announcement in January in the dev forum:

TransformerEngine has supported FP8 for a long time now. Alternatively, you can also play around with pytorch-labs/float8_experimental.
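The workflow there is module swapping: replace nn.Linear layers with FP8-aware linears and train as usual. A rough sketch, assuming the repo's swap helper (the names `swap_linear_with_float8_linear` and `Float8Linear` come from float8_experimental and may have changed as the project evolved):

```python
import torch
import torch.nn as nn

# Names below are from pytorch-labs/float8_experimental and may differ
# between revisions of the repo; treat this as an illustrative sketch.
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

model = nn.Sequential(
    nn.Linear(1024, 4096),
    nn.ReLU(),
    nn.Linear(4096, 1024),
).to("cuda")

# Replace every nn.Linear with an FP8 linear that does scaled FP8 matmuls
# in forward/backward while keeping the weights in higher precision.
swap_linear_with_float8_linear(model, Float8Linear)

# Train as usual; the swapped modules handle FP8 casting and scaling internally.
x = torch.randn(8, 1024, device="cuda", requires_grad=True)
loss = model(x).sum()
loss.backward()
```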

Thanks, my point was more related to PyTorch AMP support.