NVIDIA A100 slower than laptop CPU while doing simple ANN regression

So I have access to a cluster of A100 GPUs which I’ve been using to train simple neural networks, but training runs at about the same speed as on my laptop’s i7 CPU, or slightly slower.
Given that both the CPU and the GPU in the cluster are far better than anything in my laptop, this seems very strange to me; I’m probably missing something obvious, but I can’t see what exactly.
The GPU is utilised to some degree (20-30%), so it’s not simply a case of me forgetting to move things to the GPU.

The code tests networks of different widths via cross-validation, but I have removed most of the prints to minimise moving data back and forth between the CPU and GPU.
At the start the datasets are generated and moved to the GPU in one go. The entire dataset is about 5x8x10000 float32 values, i.e. roughly 1.6 MB, which easily fits in the A100’s 80 GB of memory, so I don’t even use a DataLoader for mini-batching and just slice in the training loop. The training data is the same for all networks.

Here are the relevant bits of the code I’m using:

import torch


class CustomStandardScaler():
    def fit(self, x: torch.Tensor):
        # per-fold mean and std along the sample dimension
        self.mean = x.mean(dim=1, keepdim=True)
        self.std = x.std(dim=1, unbiased=False, keepdim=True)

    def transform(self, x):
        # note: standardises x in place
        x -= self.mean
        x /= (self.std + 1e-7)
        return x

    def fit_transform(self, x):
        self.fit(x)
        return self.transform(x)


def train_loop(features, labels, model, loss_fn, opt, batchsize, verbose=False, device="cpu"):
    if len(features) != len(labels):
        raise ValueError("Features and labels must have same length")
    size = len(features)
    log = []
    no_batches = size // batchsize + 1
    # shuffle indices on the same device as the data to avoid transfers
    rand_indices = torch.randperm(size, device=device)
    for i in range(no_batches):
        batch_start = i * batchsize
        batch_end = min((i + 1) * batchsize, size)
        if batch_end != batch_start:
            # mini-batch by fancy-indexing the tensors already on the GPU
            X = features[rand_indices[batch_start:batch_end]]
            y = labels[rand_indices[batch_start:batch_end]]
            pred = model(X)
            loss = loss_fn(pred, y)

            # Backpropagation
            opt.zero_grad()
            loss.backward()
            opt.step()

            # if i % 10 == 0:
            #    loss, current = loss.item(), i * len(X)
            #    if verbose:
            #        print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
            #    log.append(loss)

    return log


class NeuralNetwork(torch.nn.Module):
    def __init__(self, in_feats, no_nodes):
        super(NeuralNetwork, self).__init__()
        self.stack = torch.nn.Sequential(
            torch.nn.Linear(in_feats, no_nodes),
            torch.nn.PReLU(),
            torch.nn.Linear(no_nodes,  no_nodes),
            torch.nn.PReLU(),
            torch.nn.Linear(no_nodes,  no_nodes),
            torch.nn.PReLU(),
            torch.nn.Linear(no_nodes,  no_nodes),
            torch.nn.PReLU(),
            torch.nn.Linear(no_nodes, 1),
        )

    def forward(self, x):
        return self.stack(x)


if __name__ == "__main__":

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device")

    (cv_training_feats, cv_training_labels), (cv_test_feats, cv_test_labels) = db.get_crossvalidation_sets(5, padding="max")

    # Move everything to GPU if possible
    cv_training_feats = torch.from_numpy(cv_training_feats).float().to(device)
    cv_training_labels = torch.from_numpy(cv_training_labels).float().to(device)
    cv_test_feats = torch.from_numpy(cv_test_feats).float().to(device)
    cv_test_labels = torch.from_numpy(cv_test_labels).float().to(device)

    sc = CustomStandardScaler()
    cv_training_feats = sc.fit_transform(cv_training_feats)
    cv_test_feats = sc.transform(cv_test_feats)
    for i in range(100, 101, 1):

        train_scores = torch.zeros([len(cv_training_feats)], dtype=torch.float)
        test_scores = torch.zeros([len(cv_training_feats)], dtype=torch.float)
        for j in range(len(cv_training_feats)):
            repeat = 5
            model = NeuralNetwork(len(flow_params) + 1, i).to(device)
            loss_fn = torch.nn.MSELoss()
            optimiser = torch.optim.Adam(model.parameters(), lr=0.001)

            epochs = 100
            log = []
            for t in range(epochs):
                curr_log = train_loop(cv_training_feats[j], cv_training_labels[j], model, loss_fn, optimiser,
                                      batchsize=256, verbose=False, device=device)
                log = log + curr_log

            with torch.no_grad():
                test_labels_pred = model(cv_test_feats[j])
                training_labels_pred = model(cv_training_feats[j])

                score = torch.nn.MSELoss()
                train_scores[j] = torch.sqrt(score(training_labels_pred, cv_training_labels[j]))
                test_scores[j] = torch.sqrt(score(test_labels_pred, cv_test_labels[j]))

                print(f"\r{i:03}/{j+1:02} nodes", end="")

        print(f"\r{i:03}/{len(cv_training_feats):02} nodes, ")
              # f"avg. training score: {train_scores.mean().item():.3g} "
              # f"σ = {train_scores.std().item():.3g}, "
              # f"avg. test score: {test_scores.mean().item():.3g} "
              # f"σ = {test_scores.std().item():.3g}")

I tried torch.utils.bottleneck (with far fewer epochs) to find out what’s wrong, but nothing in its output stands out to me. Perhaps I could optimise the randperm calls, but surely that alone shouldn’t cause this much of a performance difference? I have pretty much run out of ideas, so I’m asking here in case anybody can help.

The logs from torch.utils.bottleneck:

`bottleneck` is a tool that can be used as an initial step for debugging
bottlenecks in your program.

It summarizes runs of your script with the Python profiler and PyTorch's
autograd profiler. Because your script will be profiled, please ensure that it
exits in a finite amount of time.

For more complicated uses of the profilers, please see
https://docs.python.org/3/library/profile.html and
https://pytorch.org/docs/master/autograd.html#profiler for more information.
Running environment analysis...
Running your script with cProfile
Using cuda device

100/01 nodes
100/02 nodes
100/03 nodes
100/04 nodes
100/05 nodes
100/05 nodes, avg. training score: 0.0553 σ = 0.00494, avg. test score: 0.0728 σ = 0.0227
Running your script with the autograd profiler...
Using cuda device

100/01 nodes
100/02 nodes
100/03 nodes
100/04 nodes
100/05 nodes
100/05 nodes, avg. training score: 0.0562 σ = 0.00255, avg. test score: 0.0787 σ = 0.0213
Using cuda device

100/01 nodes
100/02 nodes
100/03 nodes
100/04 nodes
100/05 nodes
100/05 nodes, avg. training score: 0.0591 σ = 0.00485, avg. test score: 0.073 σ = 0.0124
--------------------------------------------------------------------------------
  Environment Summary
--------------------------------------------------------------------------------
PyTorch 1.10.1+cu111 DEBUG compiled w/ CUDA 11.1
Running with Python 3.8 and CUDA 11.1.105

`pip3 list` truncated output:
numpy==1.23.1
torch==1.10.1+cu111
torchaudio==0.10.1+rocm4.1
torchvision==0.11.2+cu111
--------------------------------------------------------------------------------
  cProfile output
--------------------------------------------------------------------------------
         3839076 function calls (3801692 primitive calls) in 9.064 seconds

   Ordered by: internal time
   List reduced from 4922 to 15 due to restriction <15>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       29    2.207    0.076    2.207    0.076 {method 'to' of 'torch._C._TensorBase' objects}
     4125    0.935    0.000    0.935    0.000 {method 'run_backward' of 'torch._C._EngineBase' objects}
     3270    0.574    0.000    0.575    0.000 {built-in method posix.stat}
      125    0.380    0.003    4.180    0.033 pytorch-nn-test.py:52(train_loop)
      725    0.338    0.000    0.338    0.000 {built-in method io.open_code}
     4125    0.320    0.000    1.298    0.000 /home/aof26/fcdb/venv/lib/python3.8/site-packages/torch/optim/_functional.py:54(adam)
    41250    0.303    0.000    0.303    0.000 {method 'mul_' of 'torch._C._TensorBase' objects}
     8270    0.271    0.000    0.271    0.000 {built-in method torch._C._nn.linear}
    41250    0.257    0.000    0.257    0.000 {method 'add_' of 'torch._C._TensorBase' objects}
     4250    0.237    0.000    0.237    0.000 {built-in method randperm}
    32030    0.205    0.000    0.343    0.000 /usr/local/software/master/python/3.8/lib/python3.8/inspect.py:625(cleandoc)
    20625    0.145    0.000    0.145    0.000 {method 'sqrt' of 'torch._C._TensorBase' objects}
    20625    0.134    0.000    0.134    0.000 {method 'addcdiv_' of 'torch._C._TensorBase' objects}
    20625    0.124    0.000    0.124    0.000 {method 'addcmul_' of 'torch._C._TensorBase' objects}
     4135    0.107    0.000    0.107    0.000 {built-in method torch._C._nn.mse_loss}


--------------------------------------------------------------------------------
  autograd profiler output (CPU mode)
--------------------------------------------------------------------------------
        top 15 events sorted by cpu_time_total

----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                        Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                aten::linear         0.01%       7.000us        58.22%      50.182ms      50.182ms             1  
                 aten::addmm         0.12%     101.000us        58.20%      50.167ms      50.167ms             1  
                aten::expand        58.04%      50.028ms        58.04%      50.028ms      50.028ms             1  
                 aten::index         0.05%      44.000us        29.56%      25.481ms      25.481ms             1  
                    aten::to        29.41%      25.353ms        29.44%      25.375ms      25.375ms             1  
                 aten::index         0.04%      35.000us         6.53%       5.632ms       5.632ms             1  
               aten::reshape         6.42%       5.532ms         6.42%       5.533ms       5.533ms             1  
                aten::linear         0.01%       5.000us         3.16%       2.721ms       2.721ms             1  
                     aten::t         0.01%       8.000us         3.07%       2.649ms       2.649ms             1  
             aten::transpose         3.06%       2.640ms         3.06%       2.641ms       2.641ms             1  
                 aten::prelu         1.55%       1.340ms         1.57%       1.352ms       1.352ms             1  
    Optimizer.step#Adam.step         0.24%     203.000us         1.48%       1.280ms       1.280ms             1  
    Optimizer.step#Adam.step         0.60%     516.000us         1.28%       1.104ms       1.104ms             1  
    Optimizer.step#Adam.step         0.23%     196.000us         1.18%       1.016ms       1.016ms             1  
    Optimizer.step#Adam.step         0.22%     193.000us         1.18%       1.015ms       1.015ms             1  
----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 86.201ms

--------------------------------------------------------------------------------
  autograd profiler output (CUDA mode)
--------------------------------------------------------------------------------
        top 15 events sorted by cpu_time_total

	Because the autograd profiler uses the CUDA event API,
	the CUDA time column reports approximately max(cuda_time, cpu_time).
	Please ignore this output if your code does not use CUDA.

--------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                      Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
--------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
  Optimizer.step#Adam.step        48.47%      60.612ms        49.13%      61.442ms      61.442ms      60.820ms        48.18%      61.437ms      61.437ms             1  
            aten::randperm         0.01%      12.000us        23.66%      29.588ms      29.588ms      20.000us         0.02%      29.583ms      29.583ms             1  
            aten::randperm        23.61%      29.518ms        23.64%      29.562ms      29.562ms      29.537ms        23.40%      29.556ms      29.556ms             1  
            aten::randperm         0.01%       9.000us        14.88%      18.607ms      18.607ms      14.000us         0.01%      18.603ms      18.603ms             1  
            aten::randperm        14.84%      18.557ms        14.86%      18.586ms      18.586ms      18.569ms        14.71%      18.582ms      18.582ms             1  
               aten::index         7.14%       8.926ms         7.23%       9.039ms       9.039ms       8.955ms         7.09%       9.035ms       9.035ms             1  
               aten::index         3.09%       3.867ms         3.16%       3.957ms       3.957ms       3.887ms         3.08%       3.953ms       3.953ms             1  
autograd::engine::evaluate_function: AddmmBackward0: AddmmBackward0         0.02%      24.000us         2.59%       3.236ms       3.236ms      43.000us         0.03%       3.231ms       3.231ms             1  
            AddmmBackward0         0.01%      18.000us         2.52%       3.153ms       3.153ms      47.000us         0.04%       3.147ms       3.147ms             1  
                   aten::t         0.01%       9.000us         2.36%       2.946ms       2.946ms      14.000us         0.01%       2.941ms       2.941ms             1  
           aten::transpose         0.01%       9.000us         2.34%       2.932ms       2.932ms       2.916ms         2.31%       2.927ms       2.927ms             1  
          aten::as_strided         0.01%      16.000us         2.33%       2.919ms       2.919ms      11.000us         0.01%      11.000us      11.000us             1  
           cudaEventCreate         2.31%       2.894ms         2.31%       2.894ms       2.894ms       0.000us         0.00%       0.000us       0.000us             1  
  Optimizer.step#Adam.step         0.22%     281.000us         1.60%       2.003ms       2.003ms     696.000us         0.55%       1.999ms       1.999ms             1  
  Optimizer.step#Adam.step         0.24%     296.000us         1.60%       2.001ms       2.001ms     698.000us         0.55%       1.996ms       1.996ms             1  
--------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 125.048ms
Self CUDA time total: 126.227ms


To isolate the bottleneck in your application I would recommend creating a profile that shows the timeline of your workload.
Based on your explanation and workload, I would guess it might be CPU-limited, i.e. the CPU isn’t fast enough to feed the GPU with data and launch the kernels.
E.g. using this code:

import torch
import torch.nn as nn
import time


class NeuralNetwork(torch.nn.Module):
    def __init__(self, in_feats, no_nodes):
        super(NeuralNetwork, self).__init__()
        self.stack = torch.nn.Sequential(
            torch.nn.Linear(in_feats, no_nodes),
            torch.nn.PReLU(),
            torch.nn.Linear(no_nodes,  no_nodes),
            torch.nn.PReLU(),
            torch.nn.Linear(no_nodes,  no_nodes),
            torch.nn.PReLU(),
            torch.nn.Linear(no_nodes,  no_nodes),
            torch.nn.PReLU(),
            torch.nn.Linear(no_nodes, 1),
        )

    def forward(self, x):
        return self.stack(x)

device = 'cpu'
N = 640
model = NeuralNetwork(N, N).to(device)
x = torch.randn(N, N, device=device)

nb_iters = 1000

# warmup
for _ in range(10):
    out = model(x)
    out.backward(torch.randn_like(out))
    model.zero_grad()

grad = torch.randn_like(out)

t0 = time.perf_counter()
for _ in range(nb_iters):
    out = model(x)
    out.backward(grad)
    model.zero_grad()
t1 = time.perf_counter()
print('cpu: {}'.format((t1 - t0)/nb_iters))


device = 'cuda'
model.to(device)
x = x.to(device)
# warmup
for _ in range(10):
    out = model(x)
    out.backward(torch.randn_like(out))
    model.zero_grad()

grad = torch.randn_like(out)

torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(nb_iters):
    out = model(x)
    out.backward(grad)
    model.zero_grad()
torch.cuda.synchronize()
t1 = time.perf_counter()
print('cuda: {}'.format((t1 - t0)/nb_iters))

I see:

# N=64
cpu: 0.0007186590805649758
cuda: 0.0013837001658976078

# N = 640
cpu: 0.010692139904946088
cuda: 0.0016351652033627033

As you can see, the GPU shows a constant offset for tiny workloads, which I would assume is caused by the kernel launches etc. (CUDA Graphs would help here).
Once the workload is increased, the GPU runtime barely changes, while the CPU slows down significantly.

Thanks for the reply!

I understand there is an offset from initialising CUDA, and I definitely see it. However, even with larger workloads, i.e. hundreds of epochs, the training loop takes about the same time on my laptop as on the GPU (even though the profile shown above was run with fewer epochs to save time).

I’m also not sure how the data loading can be CPU-limited when all the training and test data is moved to the GPU at the start and only random slices of it are used in the training loop. Or is the index generation the bottleneck?
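For what it’s worth, here is a rough, self-contained check I could run to see what the permutation and the fancy indexing cost on their own (the sizes are made up to roughly match my data, and the explicit synchronisation is there because CUDA calls are asynchronous):

import time
import torch

features = torch.randn(10000, 8, device="cuda")  # stand-in for one CV fold

torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(1000):
    idx = torch.randperm(len(features), device="cuda")  # shuffle on the GPU
    batch = features[idx[:256]]                          # fancy-indexed mini-batch
torch.cuda.synchronize()
t1 = time.perf_counter()
print(f"randperm + indexing: {(t1 - t0) / 1000 * 1e6:.1f} µs per iteration")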
What is especially strange to me is that the cluster has far superior hardware to my bog-standard laptop in every respect (CPU, GPU, RAM, storage), so even with a CPU bottleneck I’d expect much faster speeds.

The data loading itself might be the bottleneck, but as explained in my previous post, the actual workload on the GPU is tiny, which is why the kernel launch overheads are visible (and which could also make your workload CPU-limited, preventing the CPU from running ahead with the kernel scheduling). A quick check is to increase the actual workload (as done in my code) or to profile it with e.g. Nsight Systems and see how the CPU and GPU workloads behave.
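As a rough sketch of what I mean, assuming Nsight Systems is available on the cluster: launch the script with e.g. nsys profile -o timeline -t cuda,nvtx python your_script.py and wrap the interesting iterations in NVTX ranges so they are easy to find in the timeline (the model and data below are stand-ins, just to show the markup):

import torch

device = "cuda"
model = torch.nn.Linear(8, 1).to(device)        # stand-in for the real model
x = torch.randn(256, 8, device=device)          # stand-in for one mini-batch
y = torch.randn(256, 1, device=device)
loss_fn = torch.nn.MSELoss()
opt = torch.optim.Adam(model.parameters(), lr=0.001)

for step in range(20):
    torch.cuda.nvtx.range_push(f"step {step}")  # named range, visible in the Nsight timeline
    loss = loss_fn(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    torch.cuda.nvtx.range_pop()
torch.cuda.synchronize()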

If your real workload really does boil down to launching a lot of tiny kernels, take a look at CUDA Graphs, as it would help in this case (assuming it’s compatible with your use case).
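A minimal sketch of what that could look like with the (beta) torch.cuda.make_graphed_callables API, using stand-in sizes: it graphs the model’s forward and backward, the batch shape has to stay fixed (so the last partial batch would need to be dropped or padded), and the optimiser step stays eager:

import torch

device = "cuda"
in_feats, width, batchsize = 8, 100, 256               # made-up sizes roughly matching the post
model = torch.nn.Sequential(                           # stand-in for the NeuralNetwork above
    torch.nn.Linear(in_feats, width), torch.nn.PReLU(), torch.nn.Linear(width, 1)
).to(device)
loss_fn = torch.nn.MSELoss()
opt = torch.optim.Adam(model.parameters(), lr=0.001)

# Capture the model's forward/backward as CUDA graphs; later calls replay the
# captured graphs instead of launching every small kernel separately.
sample_x = torch.randn(batchsize, in_feats, device=device)
model = torch.cuda.make_graphed_callables(model, (sample_x,))

X = torch.randn(batchsize, in_feats, device=device)    # real batches must keep this shape
y = torch.randn(batchsize, 1, device=device)
for _ in range(10):
    pred = model(X)
    loss = loss_fn(pred, y)
    opt.zero_grad()
    loss.backward()
    opt.step()
torch.cuda.synchronize()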

Please profile the workload and share screenshots of the interesting training iterations.