Hi All!
For your delectation:
Short story: The intel “xpu” is about half as fast as the nvidia gpu and is about six times
faster than running on the cpu.
Pytorch has recently started supporting intel gpus (on a prototype basis). See
“Getting Started on Intel GPU.”
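For reference, using the xpu backend looks just like using cuda; here is a minimal sketch (not part of the timing test itself), assuming an xpu-enabled build of pytorch:

import torch

# 'xpu' is used as the device string, analogously to 'cuda'
if torch.xpu.is_available():
    t = torch.randn (3, 3, device = 'xpu')
    print (t.device)    # xpu:0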
Specific hardware: thinkpad p16v gen2 with intel arc graphics integrated into the
intel core ultra 9 185h main processor and a separate nvidia rtx 3000 ada generation
laptop gpu.
The timing test consists of a simple model with three Linears (no convolutions) and a plain-vanilla SGD training loop running with float32.
The top-line timing results are (with time in seconds):
device: cuda time: 49.1860 nBatch: 100000 nEpoch: 1000
device: xpu time: 100.8192 nBatch: 100000 nEpoch: 1000
device: cpu time: 582.4177 nBatch: 100000 nEpoch: 1000
The main limitations: This is an unrealistically simple test model (designed to saturate
the gpu). Various other tensor operations and more realistic models may run faster or
slower. I only tested float32.
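(If one wanted to try a lower precision, a minimal sketch, assuming the device supports float16, would be to cast the model and its data before timing, for example:

import torch

device = 'xpu' if torch.xpu.is_available() else 'cpu'

# cast a toy model and its input to float16 before running / timing it
model = torch.nn.Linear (512, 512).to (device = device, dtype = torch.float16)
inp = torch.randn (64, 512, device = device, dtype = torch.float16)
out = model (inp)

I haven't timed such a run, so I make no speed claims for it.)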
The main conclusions:
These timings are repeatable and are probably good to a couple of percent.
The integrated arc graphics xpu is significantly faster than the cpu, but not as fast as
the nvidia gpu.
A gpu-intensive job and an xpu-intensive job may be run simultaneously without
detectable degradation to either. (Also, the xpu uses the main system memory and
does not compete with the nvidia gpu memory.) (Running the model simultaneously
on the cpu and xpu slows the xpu job by about 80%. I didn’t disentangle by how much
the cpu job was slowed.)
As is typical, this machine has significantly more system memory than dedicated gpu memory, so models too large to fit on the gpu could still be accelerated on the xpu.
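(As a rough check, the device memory that pytorch reports can be queried; this is just a sketch, assuming device 0 in each case and assuming that the xpu device properties expose total_memory the same way the cuda properties do:

import torch

# total device memory, in bytes, as reported by pytorch
if torch.cuda.is_available():
    print ('cuda total_memory:', torch.cuda.get_device_properties (0).total_memory)
if torch.xpu.is_available():
    print ('xpu total_memory:', torch.xpu.get_device_properties (0).total_memory)
)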
I haven’t used the xpu version of pytorch very much. I have seen some little hiccups,
but no major problems like wrong answers or big pieces of missing functionality.
Here is the timing script:
import torch
print (torch.__version__)

import time

torch.manual_seed (2025)

# select the device: cpu, unless cuda or xpu is available (xpu takes precedence)
device = 'cpu'
if torch.cuda.is_available():
    print ('version.cuda:', torch.version.cuda)
    print (torch.cuda.get_device_name())
    print (torch.cuda.get_device_properties())
    device = 'cuda'
if torch.xpu.is_available():
    print ('version.xpu:', torch.version.xpu)
    print (torch.xpu.get_device_name())
    print (torch.xpu.get_device_properties())
    device = 'xpu'

print ('device:', device)

vBatch = 100000     # size of the fixed evaluation batch
nBatch = 100000     # size of each (freshly generated) training batch
nHidden = 512       # width of the hidden Linear layers
nEpoch = 1000       # number of training iterations that get timed
# nPrint = 100

def fitFunction (x):    # the function the model is trained to fit
    return (1 * x).sin()

lossFn = torch.nn.MSELoss()

model = torch.nn.Sequential (
    torch.nn.Linear (1, nHidden),
    torch.nn.Sigmoid(),
    torch.nn.Linear (nHidden, nHidden),
    torch.nn.Sigmoid(),
    torch.nn.Linear (nHidden, 1)
)

opt = torch.optim.SGD (model.parameters(), lr = 0.01, momentum = 0.9)

model.to (device)

# fixed batch used to evaluate the initial and final loss
inputVal = torch.randn (vBatch, 1, device = device)
targetVal = fitFunction (inputVal)

lossInit = lossFn (model (inputVal), targetVal)
print ('lossInit:', lossInit)

# synchronize so that pending gpu / xpu work doesn't skew the timing
if device == 'cuda': torch.cuda.synchronize()
if device == 'xpu': torch.xpu.synchronize()
tBeg = time.time()

for i in range (nEpoch):
    inp = torch.randn (nBatch, 1, device = device)
    trg = fitFunction (inp)
    loss = lossFn (model (inp), trg)
    opt.zero_grad()
    loss.backward()
    # if i % nPrint == 0 or i >= nEpoch:
    #     print ('i:', i, 'loss:', loss.detach(), ' time:', '{0:.2f}'.format (time.time() - tBeg))
    opt.step()

if device == 'cuda': torch.cuda.synchronize()
if device == 'xpu': torch.xpu.synchronize()
tEnd = time.time()

lossFinl = lossFn (model (inputVal), targetVal)
print ('lossFinl:', lossFinl)

print ('device:', device, ' time:', '{0:.4f}'.format (tEnd - tBeg), ' nBatch:', nBatch, ' nEpoch:', nEpoch)
And here are the timing results:
nvidia gpu:
2.6.0+cu126
version.cuda: 12.6
NVIDIA RTX 3000 Ada Generation Laptop GPU
_CudaDeviceProperties(name='NVIDIA RTX 3000 Ada Generation Laptop GPU', major=8, minor=9, total_memory=7933MB, multi_processor_count=36, uuid=f7fb8822-dc6a-36ef-c5e3-2665db44af42, L2_cache_size=32MB)
device: cuda
lossInit: tensor(0.7304, device='cuda:0', grad_fn=<MseLossBackward0>)
lossFinl: tensor(0.0290, device='cuda:0', grad_fn=<MseLossBackward0>)
device: cuda time: 49.1860 nBatch: 100000 nEpoch: 1000
intel “xpu”:
2.6.0+xpu
version.xpu: 20250001
Intel(R) Arc(TM) Graphics
_XpuDeviceProperties(name='Intel(R) Arc(TM) Graphics', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', driver_version='1.6.32567+18', total_memory=29184MB, max_compute_units=128, gpu_eu_count=128, gpu_subslice_count=8, max_work_group_size=1024, max_num_sub_groups=128, sub_group_sizes=[8 16 32], has_fp16=1, has_fp64=1, has_atomic64=1)
device: xpu
lossInit: tensor(0.7319, device='xpu:0', grad_fn=<MseLossBackward0>)
lossFinl: tensor(0.0293, device='xpu:0', grad_fn=<MseLossBackward0>)
device: xpu time: 100.8192 nBatch: 100000 nEpoch: 1000
cpu (the gpu “2.6.0+cu126” build of pytorch launched with “CUDA_VISIBLE_DEVICES=-1”):
2.6.0+cu126
device: cpu
lossInit: tensor(0.7256, grad_fn=<MseLossBackward0>)
lossFinl: tensor(0.0303, grad_fn=<MseLossBackward0>)
device: cpu time: 582.4177 nBatch: 100000 nEpoch: 1000
Best.
K. Frank