import time

import torch

try:
    import torch_xla
    import torch_xla.core.xla_model as xm
except ModuleNotFoundError:
    print("won't use tpu-xla")
    xm = None
# print(torch_xla.__version__,xm.xla_device())


def _pick_device():
    """Return the XLA device when torch_xla is importable, else CUDA if available, else CPU."""
    if xm is not None:
        return xm.xla_device()
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _wait_for_device(device):
    """Block the host until all work queued on *device* has completed.

    CUDA launches kernels asynchronously and torch_xla records operations
    lazily, so reading a host timer without this call measures only
    dispatch/tracing time, not device compute time.
    """
    if device.type == "xla":
        # Flush the pending lazy graph to the runtime, then wait for it to finish.
        xm.mark_step()
        xm.wait_device_ops()
    elif device.type == "cuda":
        torch.cuda.synchronize(device)


def run_benchmark(size=20000, outer_iters=10, inner_iters=10000, warmup_iters=1):
    """Time ``x.add(y).matmul(x)`` on the best available device and print the mean.

    Args:
        size: side length of the square input matrices.
        outer_iters: number of timed samples to average.
        inner_iters: number of ``add``+``matmul`` steps per timed sample.
        warmup_iters: untimed steps run first, so one-off costs such as
            cuBLAS initialisation or XLA graph compilation are excluded.

    Returns:
        The average wall-clock seconds per sample (i.e. per ``inner_iters`` steps).

    Raises:
        ValueError: if ``outer_iters`` is less than 1.

    NOTE: each matmul is about ``2 * size**3`` FLOPs (~1.6e13 at size=20000).
    With real device synchronization the defaults therefore run for a very
    long time; shrink ``inner_iters`` for a quick comparison.
    """
    if outer_iters < 1:
        raise ValueError("outer_iters must be >= 1")

    device = _pick_device()
    x = torch.randn(size, size).to(device)
    y = torch.randn(size, size).to(device)
    print(x.device, y.device)
    is_xla = device.type == "xla"

    def step():
        out = x.add(y).matmul(x)
        if is_xla:
            # XLA only executes graphs for live tensors. Cut the graph while
            # `out` is still referenced, otherwise the work is silently dropped.
            xm.mark_step()
        return out

    for _ in range(warmup_iters):
        step()
    _wait_for_device(device)

    times = []
    for _outer in range(outer_iters):
        start = time.perf_counter()
        for _inner in range(inner_iters):
            step()
        # Without this wait the timer stops before the device has done the work.
        _wait_for_device(device)
        times.append(time.perf_counter() - start)
    average = sum(times) / len(times)
    print(average)
    return average


if __name__ == "__main__":
    run_benchmark()
# avg:xpu:0.13987584114074708
# (recorded with the original, unsynchronized loop; it most likely reflects
#  dispatch/tracing time rather than device compute time)
Is there really such a huge speed difference, or am I doing something wrong?
This is also observed for smaller sizes, but there the difference is smaller.
Could one expect something like a 100x speedup when training a model with XLA?
The test tries to minimize everything that is not part of the calculation by using a large matrix multiplication that should consume most of the compute time. I'm unsure why the timing wouldn't point in the right direction. Could you be more explicit or give a simple example?
CUDA kernels are executed asynchronously, i.e. the host code can run ahead launching kernels while the GPU is busy with the actual kernel execution.
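If you use host timers and want to measure the real GPU kernel execution time, you need to synchronize the code before starting and before stopping the timers. XLA has the same problem in a stronger form: it records operations lazily and only executes them at `xm.mark_step()`. Operations whose results are discarded may never run at all, so call `xm.mark_step()` while the result is still alive and `xm.wait_device_ops()` before reading the timer.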
Here is a small example showing this:
import torch
import time
from torch.profiler import profile, record_function, ProfilerActivity


def _sync(device):
    """Block until all queued work on *device* has finished (no-op for non-CUDA devices)."""
    if torch.device(device).type == "cuda":
        torch.cuda.synchronize()


def time_matmul(x, nb_iters, synchronize=True):
    """Return the average host wall-clock seconds per ``torch.matmul(x, x)``.

    With ``synchronize=True`` the device is drained before starting and
    before stopping the timer, so the result reflects kernel execution time.
    With ``synchronize=False`` the timer stops as soon as the host has
    enqueued the launches. On CUDA this measures launch overhead only.

    Raises:
        ValueError: if ``nb_iters`` is less than 1.
    """
    if nb_iters < 1:
        raise ValueError("nb_iters must be >= 1")
    if synchronize:
        _sync(x.device)
    t0 = time.perf_counter()
    for _ in range(nb_iters):
        _ = torch.matmul(x, x)
    if synchronize:
        _sync(x.device)
    t1 = time.perf_counter()
    return (t1 - t0) / nb_iters


def run_demo(device="cuda", nb_iters=100, shape=(64, 1024, 1024)):
    """Compare synchronized vs. unsynchronized host timing against the profiler.

    Returns:
        ``(synced, unsynced)`` average seconds per matmul.
    """
    device_type = torch.device(device).type
    activities = [ProfilerActivity.CPU]
    if device_type == "cuda":
        activities.append(ProfilerActivity.CUDA)
    # Use the device *type* so e.g. "cuda:0" still maps to "cuda_time_total".
    sort_by_keyword = device_type + "_time_total"

    x = torch.randn(*shape, device=device)
    # warmup
    for _ in range(10):
        _ = torch.matmul(x, x)

    synced = time_matmul(x, nb_iters, synchronize=True)
    print("{:.3f}ms/iter".format(synced * 1000))
    # 7.783ms/iter

    unsynced = time_matmul(x, nb_iters, synchronize=False)
    print("{:.3f}ms/iter".format(unsynced * 1000))
    # 0.013ms/iter

    # Kernels launched by the unsynchronized loop may still be queued; drain
    # them so they do not overlap with the profiled region below.
    _sync(device)

    with profile(activities=activities, record_shapes=True) as prof:
        for _ in range(nb_iters):
            _ = torch.matmul(x, x)
    print(prof.key_averages().table(sort_by=sort_by_keyword, row_limit=10))
    # --------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
    # Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
    # --------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
    # aten::matmul 0.13% 967.727us 0.57% 4.432ms 44.320us 0.000us 0.00% 770.696ms 7.707ms 100
    # aten::bmm 0.30% 2.286ms 0.36% 2.765ms 27.647us 770.696ms 100.00% 770.696ms 7.707ms 100
    # ampere_sgemm_128x128_nn 0.00% 0.000us 0.00% 0.000us 0.000us 770.696ms 100.00% 770.696ms 7.707ms 100
    # aten::expand 0.04% 290.806us 0.05% 366.557us 1.833us 0.000us 0.00% 0.000us 0.000us 200
    # aten::as_strided 0.01% 75.751us 0.01% 75.751us 0.379us 0.000us 0.00% 0.000us 0.000us 200
    # aten::reshape 0.02% 121.196us 0.03% 260.900us 1.304us 0.000us 0.00% 0.000us 0.000us 200
    # aten::view 0.02% 139.704us 0.02% 139.704us 0.699us 0.000us 0.00% 0.000us 0.000us 200
    # cudaLaunchKernel 0.06% 478.226us 0.06% 478.226us 4.782us 0.000us 0.00% 0.000us 0.000us 100
    # aten::_unsafe_view 0.01% 72.106us 0.01% 72.106us 0.721us 0.000us 0.00% 0.000us 0.000us 100
    # cudaDeviceSynchronize 99.43% 768.014ms 99.43% 768.014ms 768.014ms 0.000us 0.00% 0.000us 0.000us 1
    # --------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
    # Self CPU time total: 772.446ms
    # Self CUDA time total: 770.696ms
    return synced, unsynced


if __name__ == "__main__":
    run_demo()
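As you can see, the properly synchronized timing (7.783 ms/iter) closely matches the per-call CUDA kernel time reported by the profiler (7.707 ms). The unsynchronized loop (0.013 ms/iter) only measures how long the host takes to launch the kernels.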