Torch Profiler: GPU substantially slower than CPU

Synopsis: Training and inference on the GPU are dramatically slower than running everything on the CPU.

Setup:

  • Training a highly customized Transformer model on an Azure VM (Standard NC6s v3 [6 vcpus, 112 GiB memory]) with a Tesla V100 (Driver Version: 550.54.15 & CUDA Version: 12.4).
  • The dataset is not very large (e.g. 1 GB) with dimensions [12000, 51, 48] using mini-batches of size 256.
  • All data are loaded into tensors and sent to GPU memory (e.g. to(device)) before even instantiating the model.
  • The model is sent to(device) and then compiled via model = torch.compile(model, mode="reduce-overhead").
  • Model training is called under torch.autocast(device_type=device, dtype=torch.bfloat16, enabled=False). A rough sketch of this setup is shown after this list.
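
For context, here is a minimal sketch of that setup. The tensor and model names (x_train, y_train, MyTransformer, config, x_batch, y_batch) are placeholders for my actual code:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# all data is moved to GPU memory up front (x_train / y_train are placeholder names)
x_train = x_train.to(device)
y_train = y_train.to(device)

# the model is moved to the GPU and then compiled
model = MyTransformer(config).to(device)  # MyTransformer / config are placeholders
model = torch.compile(model, mode="reduce-overhead")

# training/inference is wrapped in autocast (bfloat16, currently disabled)
with torch.autocast(device_type=device, dtype=torch.bfloat16, enabled=False):
    logits, loss = model(x_batch, y_batch)  # x_batch / y_batch: one mini-batch of 256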

Situation: Using Torch’s built-in Profiler, I am noticing a dramatic slowdown when using GPU/CUDA rather than calculating everything on CPU. I do not believe this has anything to do with data flowing between the CPU and GPU given my comments above, but could be wrong.

The following two tables come from exactly the same dataset, run on the same machine. The only difference between the CPU runtime and the GPU/CUDA runtime is where the data is stored (both are torch.Tensors sent to(device)) and where the compute occurs.

What’s weird: model_inference is 27.5x slower on GPU than on CPU, and model_inference is called three times in the CUDA version…


CPU Runtime

| Name | Self CPU % | Self CPU | CPU total % | CPU total | CPU time avg | # of Calls |
| --- | --- | --- | --- | --- | --- | --- |
| model_inference | 1.90% | 496.389ms | 100.00% | 26.178s | 26.178s | 1 |
| aten::empty | 0.55% | 144.920ms | 0.55% | 144.920ms | 19.006us | 7625 |
| aten::random_ | 0.00% | 12.914us | 0.00% | 12.914us | 12.914us | 1 |
| aten::item | 0.00% | 3.033us | 0.00% | 4.424us | 4.424us | 1 |
| aten::_local_scalar_dense | 0.00% | 1.391us | 0.00% | 1.391us | 1.391us | 1 |
| enumerate(DataLoader)#SingleProcessDataLoaderIter.… | 0.64% | 167.173ms | 1.34% | 349.861ms | 9.207ms | 38 |
| aten::randperm | 0.00% | 436.807us | 0.00% | 876.528us | 219.132us | 4 |
| aten::scalar_tensor | 0.00% | 7.450us | 0.00% | 7.450us | 3.725us | 2 |
| aten::resize_ | 0.01% | 1.711ms | 0.01% | 1.711ms | 2.554us | 670 |
| aten::resolve_conj | 0.01% | 2.982ms | 0.01% | 2.982ms | 0.326us | 9141 |

Self CPU time total: 26.178s


GPU/CUDA Runtime

| Name | Self CPU % | Self CPU | CPU total % | CPU total | CPU time avg | Self CUDA | Self CUDA % | CUDA total | CUDA time avg | # of Calls |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| model_inference | 0.00% | 0.00us | 0.00% | 0.00us | 0.00us | 1439.411s | 95.27% | 1439.411s | 719.705s | 2 |
| GraphLowering.run (dynamo_timed) | 0.00% | 0.00us | 0.00% | 0.00us | 0.00us | 59.987s | 3.97% | 59.987s | 810.641ms | 74 |
| CachingAutotuner.benchmark_all_configs (dynamo_timed… | 0.00% | 0.00us | 0.00% | 0.00us | 0.00us | 5.193s | 0.34% | 5.193s | 144.253ms | 36 |
| aten::fill_ | 0.13% | 950.923ms | 0.17% | 1.249s | 28.885us | 3.996s | 0.26% | 3.996s | 92.384us | 43249 |
| aten::zero_ | 0.13% | 947.515ms | 0.30% | 2.238s | 49.680us | 0.000us | 0.00% | 3.994s | 88.611us | 45049 |
| void at::native::vectorized_elementwise_kernel<4, at… | 0.00% | 0.00us | 0.00% | 0.00us | 0.00us | 3.956s | 0.26% | 3.956s | 287.800us | 13747 |
| CachingAutotuner.benchmark_all_configs (dynamo_timed… | 0.36% | 2.618s | 0.71% | 5.23s | 145.287ms | 0.000us | 0.00% | 3.737s | 103.808ms | 36 |
| model_inference | 0.43% | 3.184s | 99.58% | 731.765s | 731.765s | 0.000us | 0.00% | 2.663s | 2.663s | 1 |
| Torch-Compiled Region | 0.00% | 11.368ms | 0.52% | 3.824s | 103.364ms | 0.000us | 0.00% | 2.298s | 62.111ms | 37 |
| CompiledFunction | 0.10% | 768.567ms | 0.52% | 3.812s | 103.032ms | 158.932ms | 0.01% | 2.298s | 62.111ms | 37 |

Self CPU time total: 734.872s
Self CUDA time total: 1510.825s

Any insight as to what’s causing this would be greatly appreciated.

Update: I believe I might’ve discovered an error in my environment setup.

I installed PyTorch targeting CUDA 12.4 via the following:

conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia

When checking the CUDA version with nvcc, it reports 10.1:

$ nvcc -V

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Sun_Jul_28_19:07:16_PDT_2019
Cuda compilation tools, release 10.1, V10.1.243

Checking CUDA versions installed on the machine shows only 12.2+ versions:

$ ls -l /usr/local | grep cuda

lrwxrwxrwx 1 root root 22 Aug 20 15:39 cuda → /etc/alternatives/cuda
lrwxrwxrwx 1 root root 25 May 1 15:09 cuda-12 → /etc/alternatives/cuda-12
drwxr-xr-x 16 root root 4096 Aug 20 15:25 cuda-12.2
drwxr-xr-x 15 root root 4096 May 3 14:34 cuda-12.3
drwxr-xr-x 15 root root 4096 May 1 15:09 cuda-12.4

So, I have installed PyTorch targeting CUDA 12.4 and have CUDA 12.2, 12.3, and 12.4 installed on the machine, yet nvcc in my environment points to CUDA 10.1, which is potentially the problem…
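
For reference, the CUDA runtime that the installed PyTorch binary actually ships with (as opposed to whatever nvcc reports) can be checked from Python:

import torch

print(torch.__version__)              # PyTorch build
print(torch.version.cuda)             # CUDA runtime version the binary was built against
print(torch.cuda.is_available())      # whether the GPU is visible to PyTorch
print(torch.cuda.get_device_name(0))  # should report the Tesla V100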

Has anyone experienced this before?

A few issues stand out in your profiling:

  • You are profiling the actual compilation step of torch.compile, which shows up as a long duration.
  • autocast is not enabled (enabled=False); I assume it was also disabled in the CPU run. Note that bfloat16 is not natively supported on your Volta GPU anyway (see the quick check after this list).
  • Your locally installed CUDA toolkit won’t be used unless you build PyTorch from source or a custom CUDA extension.
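
As a quick sanity check of what the device itself supports, e.g.:

import torch

# Volta (compute capability 7.0) has FP16 tensor cores but no native bfloat16 support
print(torch.cuda.get_device_capability(0))  # (7, 0) on a V100
print(torch.cuda.is_bf16_supported())       # whether PyTorch considers bf16 usable here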

Thank you for the suggestions ptrblck.

I am running the Profiler outside the minibatch loader, like this:

import torch
from torch.profiler import profile, record_function, ProfilerActivity

# profile torch run
# (model, device, batch_loader, train, batches_per_epoch, etc. are defined earlier)
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
    with record_function("model_inference"):

        # cycle all minibatches
        for batch, data in enumerate(batch_loader, 0):
            # training data
            x_, y_ = data[0].float(), data[1].float()
            x_, y_ = x_.to(device), y_.to(device)
            with torch.autocast(device_type=device, dtype=torch.float32, enabled=True):
                logits, loss = model(x_, y_)

            loss = loss / batches_per_epoch
            if train:
                loss.backward()  # accumulate the loss

            # accumulate
            loss_accum += loss.detach()
            logits_accum = torch.cat((logits_accum, logits.detach()), dim=0)
            y_accum = torch.cat((y_accum, y_), dim=0)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

So, it cycles through an entire epoch of minibatches once before printing those results. I have experimented with profiling inside each minibatch run instead, but the results are pretty much identical, just scaled down by n_batches.
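
For the per-minibatch variant, one option is the profiler’s built-in schedule, which skips the first few steps and only records a handful afterwards; a rough sketch, reusing the loop variables from above:

from torch.profiler import profile, schedule, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=2, active=3),  # skip 3 steps, then record 3
    record_shapes=True,
) as prof:
    for batch, data in enumerate(batch_loader, 0):
        x_, y_ = data[0].float().to(device), data[1].float().to(device)
        logits, loss = model(x_, y_)
        if train:
            (loss / batches_per_epoch).backward()
        prof.step()  # advance the profiler schedule once per mini-batch

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))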

I’ve subsequently changed the autocast dtype and enabled parameters as suggested:

with torch.autocast(device_type=device, dtype=torch.float32, enabled=True):

and have matrix multiplication precision set to medium.

torch.set_float32_matmul_precision("medium")

I’m not very familiar with the last suggestion of building from source and am working to figure that out… If anyone knows of any step-by-step guides for setting up in a conda environment on a Linux server, that would be greatly appreciated.

You don’t need to build from source; the PyTorch binaries are fine unless you have a valid use case for a locally installed CUDA toolkit.
Add warmup iterations to get rid of the compilation step and to profile the actual workload afterwards.
Changing the matmul precision won’t do anything on your device, since it only enables TF32 on Ampere+ GPUs.
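
Something along these lines, using any representative batch (here x_, y_) for the warmup:

import torch
from torch.profiler import profile, record_function, ProfilerActivity

# warmup: let torch.compile finish compiling/autotuning outside the profiled region
for _ in range(3):
    logits, loss = model(x_, y_)
torch.cuda.synchronize()

# now profile the steady-state workload only
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with record_function("model_inference"):
        logits, loss = model(x_, y_)
        torch.cuda.synchronize()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))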

Hi @ptrblck - I did resolve my issue; the solution was in line with your recommendations. Thank you for taking the time to reply with suggestions.

After completely wiping my OS and starting with a fresh Ubuntu 22.04 install, NVIDIA drivers, CUDA, etc., I ultimately had two problems. First, I wasn’t careful with the dtypes of my dataset flowing into the model, and second, torch.compile() is not working, which is unfortunate. I’ll break this down for others to follow:

Issue #1 - dtypes: I am using an NVIDIA V100 on an MSFT Azure VM. The V100 is older hardware whose tensor cores support only FP16 precision. I found this out after days of searching through GitHub discussions [source: V100 can not supprt load_in_4bit and fp16? · Issue #71 · artidoro/qlora · GitHub] (hopefully NVIDIA’s documentation improves). Anyway, I originally had FP64 and Int dtypes flowing through my dataloader. Once I fixed the dtypes to FP16, the V100 GPU was able to match the speed of a CPU.
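
A rough sketch of the kind of dtype fix I mean; x_raw / y_raw are placeholders for the raw dataset tensors, and whether the model itself is kept in FP16 or FP32 depends on your setup:

import torch

# the raw data originally arrived as FP64 / Int tensors; cast once, up front,
# so only FP16 tensors ever reach the model (x_raw / y_raw are placeholder names)
x_ = x_raw.to(device=device, dtype=torch.float16)
y_ = y_raw.to(device=device, dtype=torch.float16)

# keep the model parameters in a matching dtype
model = model.to(device=device, dtype=torch.float16)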

Note: AMP (Automatic Mixed Precision) was not able to fix this problem for some reason…
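
For reference, the standard float16 AMP recipe on a V100 looks roughly like this (the optimizer and batch names are placeholders); I’m still not sure why it didn’t help in my case:

import torch

scaler = torch.cuda.amp.GradScaler()  # scales the loss to avoid FP16 gradient underflow

with torch.autocast(device_type="cuda", dtype=torch.float16):
    logits, loss = model(x_, y_)      # forward pass runs in mixed precision

scaler.scale(loss).backward()         # backward on the scaled loss
scaler.step(optimizer)                # unscales gradients, then calls optimizer.step()
scaler.update()
optimizer.zero_grad()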

Issue #2 - torch.compile: After solving the above issue, a new runtime error popped up.

torch/fx/experimental/symbolic_shapes.py:4449] [0/0_1] xindex is not in var_ranges, defaulting to unknown range.

I had no idea how to interpret this, but “experimental/symbolic_shapes” suggested that this might have something to do with pre-compiling the model. There are a handful of GitHub discussions/issues reporting this (e.g. q0 is not in var_ranges,z0 is not in var_ranges,x0 is not in var_ranges, · Issue #2418 · MIC-DKFZ/nnUNet · GitHub), so I simply commented out the following before runtime:

model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
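
One way to keep this easy to flip back on later is a simple switch (USE_COMPILE is just a hypothetical flag, not anything built into PyTorch):

USE_COMPILE = False  # turned off until the var_ranges issue is resolved

if USE_COMPILE:
    model = torch.compile(model, mode="reduce-overhead", fullgraph=True)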

Success! It now works as intended.

I don’t know why torch.compile doesn’t like my model. It’s pretty similar to GPT-3, although I do have some custom layers and configurations, all built within the PyTorch ecosystem.

My model now runs substantially faster than on a CPU. What used to take 3 minutes per epoch now takes less than a dozen seconds.

Thank you again for your dedication to Pytorch and helping others.
