What does the `use_fast_accum` option do in `torch._scaled_mm`?

Recent PyTorch support for FP8 in torch._scaled_mm includes a use_fast_accum flag that I have found to increase throughput by a noticeable amount.
However, even after going through the CUDA code, I was unable to find out what this option does and what potential effects it may have on the matrix multiplication outputs.
May I ask for help on where to find detailed documentation on the matter?
If such documentation is not available, I would like to request that it be added, maybe in torch.ao.

Many thanks in advance!

This flag enables CUBLASLT_MATMUL_DESC_FAST_ACCUM, which is defined as:

Flag for managing FP8 fast accumulation mode. When enabled, problem execution might be faster but at the cost of lower accuracy because intermediate results will not periodically be promoted to a higher precision. Default value: 0 - fast accumulation mode is disabled

in the cuBLAS docs.
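
If you want to see the accuracy/throughput trade-off from the Python side, here is a rough sketch. It is only illustrative: the exact torch._scaled_mm signature and return value have changed between PyTorch releases (older builds return an (out, amax) tuple), and FP8 matmul needs a supported GPU such as an H100, so treat the call shape below as an assumption for a recent build.

```python
import torch

# Minimal sketch: assumes an FP8-capable GPU and a recent PyTorch build;
# the _scaled_mm signature/return value differ on older versions.
device = "cuda"
M, K, N = 128, 256, 64

a = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
# _scaled_mm expects the second operand in column-major layout, which a
# transposed row-major tensor provides.
b = torch.randn(N, K, device=device).to(torch.float8_e4m3fn).t()

scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)

# Same inputs, with and without fast accumulation.
out_ref = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b,
                           out_dtype=torch.bfloat16, use_fast_accum=False)
out_fast = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b,
                            out_dtype=torch.bfloat16, use_fast_accum=True)

# Any difference comes from partial sums being promoted to higher precision
# less often when fast accumulation is enabled.
print((out_ref.float() - out_fast.float()).abs().max())
```

For small K the two outputs are usually identical; the accumulation error only becomes visible as the reduction dimension grows.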
