With the H100 supporting FP8, is there any plan to support FP8 training in PyTorch?
It looks like the only alternative is to use FP8 through NVIDIA Transformer Engine, and even Transformer Engine doesn't cover all nn modules — conv layers, for example, have no FP8 support.
Hey @navmarri, we have actually already landed some FP8 primitives in PyTorch. Specifically, we have landed the fp8e4m3 and fp8e5m2 dtypes and a private function (with no BC/FC guarantees) for doing a scaled matmul on H100 machines.
For a quick example script of how you can use it, check out:
As far as I know (and as also mentioned in NVIDIA/TransformerEngine), the RTX 4000 series (compute capability 8.9) should also support FP8 computation.
Does PyTorch support that too, or is it H100-only for now, with further work needed for sm8.9 cards?