Slow fp16 GEMM on 4090

I recently opened an issue on GitHub. The documented speed of fp16 GEMM on the 4090 is 330 TFLOPS (with fp16 accumulation) and 165 TFLOPS (with fp32 accumulation). I ran the following benchmark code on my 4090 machine.

import torch

print(torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction)

x = torch.randn(4096, 4096, device="cuda:0", dtype=torch.float16)
y = torch.randn(4096, 4096, device="cuda:0", dtype=torch.float16)
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

# warmup
for _ in range(1000):
    out = x.mm(y)

start.record()
for _ in range(1000):
    out = x.mm(y)
end.record()
end.synchronize()

# elapsed_time() is in milliseconds for 1000 GEMMs, so time / 1000 is ms per GEMM
time = start.elapsed_time(end)
print(f"TFLOPs {4096*4096*4096*2/1e9/(time/1000)}")

The output is:

True
TFLOPs 168.63091825340712

The default value of allow_fp16_reduced_precision_reduction is True, but PyTorch still uses fp32 as the accumulation type for fp16 GEMM. Switching the accumulation type to fp16 roughly doubles the throughput on some GPU architectures, such as the 4090.
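To back this up, here is a minimal sketch (reusing the benchmark above) that toggles the existing flag and re-times the same GEMM. If the flag affected this path, the two numbers should differ; if the accumulation type really stays fp32 either way, both should land near the fp32-accumulation figure.

import torch

def bench_mm(x, y, iters=1000):
    # Time `iters` back-to-back GEMMs with CUDA events; elapsed_time() returns milliseconds.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(100):  # warmup
        x.mm(y)
    start.record()
    for _ in range(iters):
        x.mm(y)
    end.record()
    end.synchronize()
    ms_per_iter = start.elapsed_time(end) / iters
    return 2 * x.shape[0] * x.shape[1] * y.shape[1] / (ms_per_iter * 1e9)  # TFLOPS

x = torch.randn(4096, 4096, device="cuda:0", dtype=torch.float16)
y = torch.randn(4096, 4096, device="cuda:0", dtype=torch.float16)

for flag in (True, False):
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = flag
    print(f"allow_fp16_reduced_precision_reduction={flag}: {bench_mm(x, y):.1f} TFLOPS")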

I'd like to submit a PR to change this behavior, so that the option actually enables the low-precision accumulation mode on architectures that can benefit from it. Would this feature surprise people because of the change in precision? Maybe we should set the default value of this option to False?

I don’t fully understand the idea, as reduced-precision reductions in e.g. split-k kernels are already allowed (assuming a split-k kernel is used). A “pure” FP16 kernel using FP16 as both the compute and accumulation type is not a good idea, as convergence issues were shown in the past.
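For concreteness, here is a rough PyTorch emulation of the three cases being discussed, assuming a hypothetical split-K factor of 4 and a rank-1-update loop to mimic fp16 accumulation; this only illustrates the numerics, it is not what the cuBLAS kernels actually do.

import torch

torch.manual_seed(0)
M = K = N = 1024   # smaller than the benchmark so the emulation stays fast
splits = 4         # illustrative split-K factor, not the real kernel's choice

x = torch.randn(M, K, device="cuda", dtype=torch.float16)
y = torch.randn(K, N, device="cuda", dtype=torch.float16)
ref = x.double() @ y.double()  # fp64 reference

# (a) fp32 accumulation, as the fp16 GEMM does today
acc_fp32 = (x.float() @ y.float()).half()

# (b) split-K: fp32 accumulation inside each split, fp16 adds across splits
#     (roughly what allow_fp16_reduced_precision_reduction permits)
parts = [(x[:, i::splits].float() @ y[i::splits, :].float()).half()
         for i in range(splits)]
acc_splitk = parts[0]
for p in parts[1:]:
    acc_splitk = acc_splitk + p

# (c) rough emulation of "pure" fp16 accumulation: every rank-1 update is added in fp16
#     (each partial product is also rounded to fp16 here, which slightly overstates the error)
acc_fp16 = torch.zeros(M, N, device="cuda", dtype=torch.float16)
for k in range(K):
    acc_fp16 += x[:, k:k + 1] @ y[k:k + 1, :]

for name, out in [("fp32 acc", acc_fp32), ("split-K", acc_splitk), ("fp16 acc", acc_fp16)]:
    print(f"{name:9s} max abs error vs fp64: {(out.double() - ref).abs().max().item():.4f}")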

Thanks for your reply! I share the same concerns about precision. Mixed precision is important for NN training, but this feature is aimed at improving inference speed.

I’m currently optimizing LLM inference speed on the 4090. When switching the compute mode to FP16, perplexity on LLaMA2-7B increases only slightly (by 0.009), but there is a 40% end-to-end speedup.
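For reference, perplexity here is the exponential of the average per-token cross-entropy on a held-out text. Below is a minimal sketch of one way to measure it with Hugging Face transformers (the model id and eval.txt are placeholders, and the windows are non-overlapping for simplicity); running it once per accumulation mode gives the two numbers being compared.

import math
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"   # placeholder model id
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).cuda().eval()

text = open("eval.txt").read()           # placeholder held-out text
ids = tok(text, return_tensors="pt").input_ids.cuda()

stride = 2048                            # non-overlapping windows, for simplicity
nll, n_tokens = 0.0, 0
with torch.no_grad():
    for i in range(0, ids.size(1), stride):
        chunk = ids[:, i:i + stride]
        if chunk.size(1) < 2:
            break
        out = model(chunk, labels=chunk)              # HF shifts the labels internally
        nll += out.loss.item() * (chunk.size(1) - 1)  # loss is mean NLL per predicted token
        n_tokens += chunk.size(1) - 1

print("perplexity:", math.exp(nll / n_tokens))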

Can we add an option to support this kind of “pure” FP16 GEMM? Users who turn it on would get the performance benefit while accepting the risk of precision issues.
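Something along these lines is what I have in mind, purely as a sketch: the flag name allow_fp16_accumulation is hypothetical here (chosen by analogy with the existing flag) and does not exist in PyTorch today.

import torch

# Hypothetical opt-in, named by analogy with allow_fp16_reduced_precision_reduction;
# this attribute is only a sketch of the proposed API surface, not a real PyTorch flag.
torch.backends.cuda.matmul.allow_fp16_accumulation = True

x = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
y = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
out = x.mm(y)  # would dispatch to an fp16-accumulation (CUBLAS_COMPUTE_16F) GEMM where beneficial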

I’m not too familiar with these metrics, but I assume this value is not considered large and would still yield the same outputs/predictions?

Could you create a feature request for the inference use case on GitHub so we could discuss it in more detail there?

Sure, the new issue is Support FP16 accumulation for faster LLM inference on 4090 like GPUs · Issue #123558 · pytorch/pytorch · GitHub. I hope to hear from the community!