I noticed that matrix multiplication with torch.float32 is roughly 40 times faster than with torch.float64.
import torch

torch.set_default_device('cuda:0')  # cuda:0 is an A6000

def timer_torch(x, y):
    z = x @ y
    torch.cuda.synchronize()  # wait for the kernel to finish so the timing is meaningful
    return
float32
x = torch.randn(100,1000000)
y = torch.randn(1000000,100)
%timeit timer_torch(x, y)
1.68 ms ± 4.51 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
float64
x = torch.randn(100,1000000).to(torch.float64)
y = torch.randn(1000000,100).to(torch.float64)
%timeit timer_torch(x, y)
68 ms ± 325 ns per loop (mean ± std. dev. of 7 runs, 10 loops each)
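In case %timeit is the wrong tool for GPU kernels, here is a CUDA-event-based timing I would use to cross-check (a minimal sketch reusing x and y from above; torch.cuda.Event is the standard API for this):

# Cross-check timing with CUDA events, which measure GPU time directly.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

_ = x @ y  # warm-up so one-time launch/setup cost is excluded
torch.cuda.synchronize()

start.record()
for _ in range(100):
    z = x @ y
end.record()
torch.cuda.synchronize()
print(f"{start.elapsed_time(end) / 100:.3f} ms per matmul")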
So on my A6000 with torch 2.1 and CUDA 12.1, float32 is about 40 times faster than float64.
Is this the expected behavior? I want to make sure it is not the result of some mismatch in libraries that would produce such a dramatic difference.
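For reference, this is how I would verify the build and runtime setup to rule out a mismatch (all standard torch attributes):

print(torch.__version__)                       # torch build, e.g. 2.1.x
print(torch.version.cuda)                      # CUDA version torch was built against
print(torch.cuda.get_device_name(0))           # should report the A6000
print(torch.backends.cuda.matmul.allow_tf32)   # TF32 can affect float32 matmul speed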
If the result is expected, are there strategies to leverage the GPU when working with big matrices where a higher, 64-bit precision is necessary?
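To make the precision requirement concrete, here is a rough sketch of the check I have in mind: do the matmul in float32, compare against the float64 result, and only pay for float64 when the error actually matters (the error metric here is just an illustrative choice):

# Hypothetical accuracy check: how much precision is lost by a float32 matmul?
x64 = torch.randn(100, 1000000, dtype=torch.float64)
y64 = torch.randn(1000000, 100, dtype=torch.float64)

z64 = x64 @ y64
z32 = (x64.float() @ y64.float()).double()  # same product in float32, cast back

rel_err = (z32 - z64).abs().max() / z64.abs().max()
print(f"max relative error of float32 matmul: {rel_err.item():.2e}")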