I recently opened an issue on GitHub. The documented fp16 GEMM throughput of the 4090 is 330 TFLOPS (with fp16 accumulation) and 165 TFLOPS (with fp32 accumulation). I ran the following benchmark code on my 4090 machine.
import torch

# Check whether reduced-precision (fp16) reductions are allowed for fp16 GEMM.
print(torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction)

x = torch.randn(4096, 4096, device="cuda:0", dtype=torch.float16)
y = torch.randn(4096, 4096, device="cuda:0", dtype=torch.float16)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

# warmup
for _ in range(1000):
    out = x.mm(y)

start.record()
for _ in range(1000):
    out = x.mm(y)
end.record()
end.synchronize()

# elapsed_time is in milliseconds over 1000 iterations; each GEMM is
# 2 * 4096^3 FLOPs, and GFLOP per ms equals TFLOP/s.
time = start.elapsed_time(end)
print(f"TFLOPs {4096*4096*4096*2/1e9/(time/1000)}")
The output is:
True
TFLOPs 168.63091825340712
The default value of allow_fp16_reduced_precision_reduction is True, but PyTorch still uses fp32 as the accumulation type for fp16 GEMM. Switching the accumulation type to fp16 would roughly double throughput on some GPU architectures, such as the 4090.
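Since the flag can be toggled at runtime, one way to check whether it currently affects the GEMM path is to rerun the same timing loop under both settings and compare the throughput. The bench helper below is just a hypothetical wrapper around the benchmark above, not part of any PyTorch API.

import torch

def bench(x, y, iters=1000):
    # Hypothetical helper: same timing loop as the benchmark above.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(iters):  # warmup
        x.mm(y)
    start.record()
    for _ in range(iters):
        x.mm(y)
    end.record()
    end.synchronize()
    ms = start.elapsed_time(end) / iters   # average ms per GEMM
    m, k = x.shape
    n = y.shape[1]
    return 2 * m * n * k / 1e9 / ms        # GFLOP per ms == TFLOP/s

x = torch.randn(4096, 4096, device="cuda:0", dtype=torch.float16)
y = torch.randn(4096, 4096, device="cuda:0", dtype=torch.float16)

for allow in (True, False):
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = allow
    print(allow, bench(x, y))

In my runs, both settings land around 168 TFLOPS, which is what leads me to believe the flag is not reaching the cuBLAS accumulation choice for plain fp16 GEMM.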
I want to submit a PR to change this behavior, so that the option actually enables the low-precision accumulation mode on architectures that benefit from it. Will this feature surprise people because of the change in precision? Maybe we should set the default value of this option to False?
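On the precision question, one way to quantify how surprising the change might be is to compare the fp16 GEMM result against an fp32 reference. This is only an illustrative sketch (the max-relative-error metric is my choice); it says nothing about how large the error actually is under fp16 accumulation on any particular architecture.

import torch

x = torch.randn(4096, 4096, device="cuda:0", dtype=torch.float16)
y = torch.randn(4096, 4096, device="cuda:0", dtype=torch.float16)

out_fp16 = x.mm(y)                  # fp16 GEMM; accumulation type depends on the backend
ref_fp32 = x.float().mm(y.float())  # fp32 reference

rel_err = (out_fp16.float() - ref_fp32).abs().max() / ref_fp32.abs().max()
print(f"max relative error vs fp32 reference: {rel_err.item():.3e}")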