Why is there such a large difference using XLA in this snippet?

The results for an operation are as follows:

  • CUDA-enabled GPU (T4): the calculation does not even finish after many seconds, maybe minutes. The same happens on CPU.
  • XLA + CUDA-enabled GPU: fast, approximately the same speed as TPU.
  • XLA-enabled TPU: very fast.

Results and Repro

This can be reproduced in a Colab notebook.

  • Remove tensorflow first using:
pip uninstall tensorflow -y -q
  • For the XLA + CUDA test only, run:
pip install torch==2.5.1 torch_xla==2.5.1 https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla_cuda_plugin-2.5.1-py3-none-any.whl -q

Then paste the code:

import torch

try:
  import torch_xla
  import torch_xla.core.xla_model as xm
except ModuleNotFoundError:
  print("won't use tpu-xla")
  xm = None

# print(torch_xla.__version__,xm.xla_device())

import time

size = 20000
if xm is not None:
  x = torch.randn(size, size).to(xm.xla_device())
  y = torch.randn(size, size).to(xm.xla_device())
else:
  device = "cuda" if torch.cuda.is_available() else "cpu"
  x = torch.randn(size, size).to(device)
  y = torch.randn(size, size).to(device)

print(x.device, y.device)

times = []
for j in range(10):
  s = time.time()
  for i in range(10000):
    x.add(y).matmul(x)
  e = time.time()
  times.append(e - s)
average = sum(times) / len(times)
print(average)

# avg:xpu:0.13987584114074708

Is there really such a huge speed difference, or am I doing something wrong?

The same is observed for smaller sizes as well, but the difference is less pronounced.

Could one expect something like a 100x speedup when training a model with XLA?

You are using host timers without any device synchronization, so your profiling is not measuring the kernel execution time but the launch overhead.

The test tries to reduce anything that is not part of the calculation by using a large matrix multiplication that should consume most of the compute time. I'm unsure why the timing wouldn't point in the right direction; could you be more explicit or give a simple example?

CUDA kernels are executed asynchronously, i.e. the host code can run ahead launching kernels while the GPU is still busy with the actual kernel execution.
If you use host timers and want to measure the real GPU kernel execution time, you need to synchronize the code before starting and stopping the timers.

Here is a small example showing this:

import torch
import time
from torch.profiler import profile, record_function, ProfilerActivity

device = "cuda"
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
sort_by_keyword = device + "_time_total"

x = torch.randn(64, 1024, 1024, device=device)

# warmup
for _ in range(10):
    _ = torch.matmul(x, x)

nb_iters = 100

torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(nb_iters):
    _ = torch.matmul(x, x)
torch.cuda.synchronize()
t1 = time.perf_counter()
print("{:.3f}ms/iter".format((t1 - t0) / nb_iters * 1000))
# 7.783ms/iter

t0 = time.perf_counter()
for _ in range(nb_iters):
    _ = torch.matmul(x, x)
t1 = time.perf_counter()
print("{:.3f}ms/iter".format((t1 - t0) / nb_iters * 1000))
# 0.013ms/iter


with profile(activities=activities, record_shapes=True) as prof:
    for _ in range(nb_iters):
        _ = torch.matmul(x, x)
print(prof.key_averages().table(sort_by=sort_by_keyword, row_limit=10))
# ---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                        Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
# ---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                aten::matmul         0.13%     967.727us         0.57%       4.432ms      44.320us       0.000us         0.00%     770.696ms       7.707ms           100  
#                   aten::bmm         0.30%       2.286ms         0.36%       2.765ms      27.647us     770.696ms       100.00%     770.696ms       7.707ms           100  
#     ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us     770.696ms       100.00%     770.696ms       7.707ms           100  
#                aten::expand         0.04%     290.806us         0.05%     366.557us       1.833us       0.000us         0.00%       0.000us       0.000us           200  
#            aten::as_strided         0.01%      75.751us         0.01%      75.751us       0.379us       0.000us         0.00%       0.000us       0.000us           200  
#               aten::reshape         0.02%     121.196us         0.03%     260.900us       1.304us       0.000us         0.00%       0.000us       0.000us           200  
#                  aten::view         0.02%     139.704us         0.02%     139.704us       0.699us       0.000us         0.00%       0.000us       0.000us           200  
#            cudaLaunchKernel         0.06%     478.226us         0.06%     478.226us       4.782us       0.000us         0.00%       0.000us       0.000us           100  
#          aten::_unsafe_view         0.01%      72.106us         0.01%      72.106us       0.721us       0.000us         0.00%       0.000us       0.000us           100  
#       cudaDeviceSynchronize        99.43%     768.014ms        99.43%     768.014ms     768.014ms       0.000us         0.00%       0.000us       0.000us             1  
# ---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
# Self CPU time total: 772.446ms
# Self CUDA time total: 770.696ms

As you can see, the properly synchronized code shows the same kernel execution time as the profiler.
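
The XLA path in the original snippet needs explicit synchronization as well: lazy XLA tensors only record operations into a graph, and nothing is compiled or executed until that graph is materialized, so a host timer there is likely measuring graph tracing rather than execution. Below is a minimal sketch (not from the original post) of how the operation could be timed on both backends; it assumes xm.mark_step() and xm.wait_device_ops() for the XLA case, torch.cuda.synchronize() for the CUDA case, and a smaller matrix size chosen purely for illustration.

import time
import torch

try:
    import torch_xla.core.xla_model as xm
except ModuleNotFoundError:
    xm = None

size = 8192  # smaller than the original 20000, for illustration only
device = xm.xla_device() if xm is not None else ("cuda" if torch.cuda.is_available() else "cpu")
x = torch.randn(size, size, device=device)
y = torch.randn(size, size, device=device)

def run_step():
    # Keep a reference to the result so the work cannot be optimized away.
    z = x.add(y).matmul(x)
    if xm is not None:
        xm.mark_step()  # cut the lazy graph so this step is compiled and dispatched
    return z

def wait():
    # Block the host until all queued device work has finished.
    if xm is not None:
        xm.wait_device_ops()
    elif device == "cuda":
        torch.cuda.synchronize()

# warmup (covers XLA compilation and CUDA context/cuBLAS setup)
for _ in range(3):
    run_step()
wait()

nb_iters = 10
t0 = time.perf_counter()
for _ in range(nb_iters):
    run_step()
wait()
t1 = time.perf_counter()
print("{:.3f}ms/iter".format((t1 - t0) / nb_iters * 1000))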


Makes sense, thanks for helping.