Hi @KFrank!
Thanks for the detailed response! Your comments gave me the idea of trying functorch to see whether it affects performance. The slowdown seems to primarily hit larger matrices; for small matrices, functorch actually makes the GPU faster than the CPU. Results are below, with all times measured in seconds.
version: 1.12.0a0+git7c2103a
CUDA: 11.6
| Tensor size | CPU time (s) | GPU time (s) | FUNC time (s) |
|---|---|---|---|
| (1000, 32, 32) | 0.0594 | 0.5319 | 0.0046 |
| (1000, 64, 64) | 0.1273 | 0.8110 | 0.7586 |
| (1000, 128, 128) | 0.4363 | 1.8971 | 1.9027 |
| (1000, 256, 256) | 1.5289 | 5.5905 | 5.5319 |
| (10000, 32, 32) | 0.4745 | 0.5152 | 0.0416 |
| (10000, 64, 64) | 1.1662 | 6.9584 | 6.9555 |
| (10000, 128, 128) | 3.8899 | 18.1341 | 18.8087 |
The script below reproduces these results:
```python
from time import time

import torch
from functorch import vmap

# batched eigh; same as vmap(torch.linalg.eigh, in_dims=0)(matrices)
veigh = vmap(torch.linalg.eigh)

print("version: ", torch.__version__)
print("CUDA: ", torch.version.cuda, "\n")

for B in [1000, 10000]:
    for N in [32, 64, 128, 256]:
        matrices = torch.randn(B, N, N)
        # make each matrix symmetric so eigh is applicable
        matrices = matrices @ matrices.transpose(-2, -1)

        torch.cuda.synchronize()
        t1 = time()
        torch.linalg.eigh(matrices)
        torch.cuda.synchronize()
        t2 = time()
        cpu_time = t2 - t1

        matrices = matrices.to(torch.device('cuda'))
        torch.cuda.synchronize()
        t1 = time()
        torch.linalg.eigh(matrices)
        torch.cuda.synchronize()
        t2 = time()
        gpu_time = t2 - t1

        torch.cuda.synchronize()
        t1 = time()
        veigh(matrices)
        torch.cuda.synchronize()
        t2 = time()
        func_time = t2 - t1

        print("Size of Tensor: ", matrices.shape)
        print("CPU time: ", cpu_time)
        print("GPU time: ", gpu_time)
        print("FUNC time: ", func_time, "\n")
```
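One caveat about the numbers above: a single `time()` pair includes one-off costs (kernel compilation, allocator warmup) and can be noisy run to run. A warmup-plus-median harness is a common way to reduce that noise. Below is a minimal stdlib-only sketch (`bench` is a hypothetical helper I made up for illustration, shown here on a cheap CPU workload; a CUDA variant would additionally synchronize or use `torch.cuda.Event` around the timed call):

```python
from time import perf_counter
from statistics import median

def bench(fn, *args, warmup=2, repeats=5):
    """Time fn(*args): run `warmup` untimed passes first,
    then return the median wall time of `repeats` timed passes (seconds)."""
    for _ in range(warmup):
        fn(*args)
    times = []
    for _ in range(repeats):
        t0 = perf_counter()
        fn(*args)
        times.append(perf_counter() - t0)
    return median(times)

# usage with a cheap stand-in workload
t = bench(sorted, list(range(100_000)))
print(f"median time: {t:.6f} s")
```

The median of several repeats is less sensitive to a single slow outlier than either the mean or a one-shot measurement, which matters when GPU and CPU runs are being compared at small sizes.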