GPU is "only" 6x faster than CPU

I ran the following code on both CPU and GPU, and the speedup on the GPU is disappointing.

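```python
import random

import torch
import torch.nn.functional as F

# Assumed to be defined earlier (not shown in the post):
#   words  - list of training names (strings)
#   stoi   - dict mapping each character and '.' to an index in 0..26
#   device - 'cuda' or 'cpu'

def build_dataset(words):
    block_size = 3  # context length: number of characters used to predict the next one
    X, Y = [], []
    for w in words:
        context = [0] * block_size
        for ch in w + '.':
            ix = stoi[ch]
            X.append(context)               # input: the current context window
            Y.append(ix)                    # target: the next character
            context = context[1:] + [ix]    # slide the window forward

    X = torch.tensor(X).to(device)
    Y = torch.tensor(Y).to(device)
    return X, Y

n1 = int(0.8 * len(words))
n2 = int(0.9 * len(words))

X_train, Y_train = build_dataset(words[:n1])
X_dev, Y_dev = build_dataset(words[n1:n2])
X_test, Y_test = build_dataset(words[n2:])

g = torch.Generator().manual_seed(2147483647)
C = torch.randn((27, 10), generator=g).to(device)
W1 = torch.randn((30, 200), generator=g).to(device)
b1 = torch.randn((1, 200), generator=g).to(device)
W2 = torch.randn((200, 27), generator=g).to(device)
b2 = torch.randn((1, 27), generator=g).to(device)
parameters = [C, W1, b1, W2, b2]
for p in parameters:
    p.requires_grad = True

batch_size = 25000
for i in range(5):
    # minibatch (indices created directly on the target device)
    ix = torch.randint(0, X_train.shape[0], (batch_size,), device=device)

    # forward pass
    emb = C[X_train[ix]]                        # (batch_size, 3, 10)
    h = torch.tanh(emb.view(-1, 30) @ W1 + b1)  # (batch_size, 200)
    logits = h @ W2 + b2                        # (batch_size, 27)
    loss = F.cross_entropy(logits, Y_train[ix])

    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    # update
    lr = 0.01
    for p in parameters:
        p.data += -lr * p.grad
    print(i, loss.item())
```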

I am using an A4000 GPU and 8 vCPUs (AMD EPYC 7502, 2.5 GHz). When I set device to 'cuda', the training loop runs in 40 ms. When I set device to 'cpu', the same loop takes 250 ms.
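CUDA kernels are launched asynchronously, so a wall-clock reading of GPU code is only reliable once the host has waited for the queued work with `torch.cuda.synchronize()`. The sketch below shows one way to time a workload under that assumption; the helper names `sync` and `time_it` and the matmul workload (which borrows the post's hidden-layer shapes) are hypothetical illustrations, not the full training loop.

```python
import time

import torch

# Fall back to CPU so the sketch also runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

def sync():
    # Wait until all queued CUDA kernels have finished before reading the clock.
    if device == "cuda":
        torch.cuda.synchronize()

def time_it(fn, warmup=3, iters=10):
    """Return the mean wall-clock time of fn() in milliseconds."""
    for _ in range(warmup):  # exclude one-time costs (CUDA context, cuBLAS setup, allocator)
        fn()
    sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()
    return (time.perf_counter() - start) * 1000 / iters

# Hypothetical workload with the same shapes as the post's first layer.
a = torch.randn(25000, 30, device=device)
w = torch.randn(30, 200, device=device)
ms = time_it(lambda: torch.tanh(a @ w))
print(f"{device}: {ms:.3f} ms per call")
```

The warm-up iterations matter on the GPU: the first calls pay for context creation and library initialization, which would otherwise dominate a measurement this short.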

I set batch_size to maximize GPU usage as reported by nvidia-smi. In CPU mode, htop shows all eight vCPUs running at full capacity.
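The utilization figure in nvidia-smi reports the fraction of time during which at least one kernel is running, not how much of the GPU's arithmetic capacity is in use, so a high reading does not mean the GPU is saturated. A back-of-envelope count of the matmul work in the loop gives a sense of scale. This sketch assumes the common rule of thumb that backward costs about twice the forward pass, ignores the embedding gather, elementwise ops and cross-entropy, and uses an assumed FP32 peak of about 19.2 TFLOP/s for the A4000.

```python
def matmul_flops(m, k, n):
    # An (m x k) @ (k x n) product does m*k*n multiply-adds, counted as 2*m*k*n FLOPs.
    return 2 * m * k * n

batch_size = 25000
# The two matmuls in the forward pass: (B x 30) @ (30 x 200) and (B x 200) @ (200 x 27).
forward = matmul_flops(batch_size, 30, 200) + matmul_flops(batch_size, 200, 27)
# Rule-of-thumb assumption: backward is ~2x forward, so one step is ~3x forward.
per_step = 3 * forward
total = 5 * per_step  # the loop runs 5 iterations

# Assumed peak FP32 throughput of an RTX A4000; check the official spec sheet.
peak_flops = 19.2e12
ideal_ms = total / peak_flops * 1000
print(f"{total / 1e9:.2f} GFLOP in total, about {ideal_ms:.2f} ms at peak")
```

Under these assumptions the five steps need roughly 8.5 GFLOP, i.e. well under 1 ms of ideal compute, which is far below the measured 40 ms. Real kernels never reach peak, but the gap suggests the time is going somewhere other than arithmetic.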

I expected the GPU to bring much more speedup. Am I missing something?

Apart from the large indexing operation, your GPU workload seems tiny, so you might be CPU-limited (bound by the host-side cost of launching many small kernels rather than by GPU compute). Profile your use case with the native PyTorch profiler or Nsight Systems to better understand the performance and bottlenecks of your code. If this use case represents your real workload, you might want to look into CUDA Graphs, which reduce the CPU overhead of launching kernels.
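A minimal sketch of profiling the training step with `torch.profiler`, assuming random stand-in tensors with the post's shapes (hypothetical data, not the real dataset). The table is sorted by total CPU time, which shows where host-side time goes; on a CUDA machine the GPU kernels are recorded as well, so a long list of short kernels with large CPU time around them would point to launch overhead.

```python
import torch
import torch.nn.functional as F
from torch.profiler import ProfilerActivity, profile

device = "cuda" if torch.cuda.is_available() else "cpu"
activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)  # also record GPU kernels

# Hypothetical stand-in data and parameters with the post's shapes.
g = torch.Generator().manual_seed(0)
X = torch.randint(0, 27, (25000, 3), generator=g).to(device)
Y = torch.randint(0, 27, (25000,), generator=g).to(device)
C = torch.randn((27, 10), generator=g).to(device).requires_grad_()
W1 = torch.randn((30, 200), generator=g).to(device).requires_grad_()
b1 = torch.randn((1, 200), generator=g).to(device).requires_grad_()
W2 = torch.randn((200, 27), generator=g).to(device).requires_grad_()
b2 = torch.randn((1, 27), generator=g).to(device).requires_grad_()
parameters = [C, W1, b1, W2, b2]

def step(lr=0.01):
    # One full-batch training step: forward, backward, SGD update.
    emb = C[X]
    h = torch.tanh(emb.view(-1, 30) @ W1 + b1)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Y)
    for p in parameters:
        p.grad = None
    loss.backward()
    with torch.no_grad():
        for p in parameters:
            p -= lr * p.grad
    return loss

with profile(activities=activities) as prof:
    for _ in range(5):
        loss = step()

table = prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)
print(table)
```

For a timeline view of the same run, Nsight Systems shows the gaps between kernels directly, which makes CPU-bound behavior easy to spot.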

Thanks for your help!