Extremely slow training speed on RTX 5070 Laptop GPU despite latest PyTorch & CUDA setup

Hi PyTorch community,

I’m experiencing unusually slow training performance on my new NVIDIA GeForce RTX 5070 Laptop GPU. Even with the latest PyTorch nightly (2.9.0.dev20250704+cu128) and CUDA 12.8 installed, GPU training is much slower than expected: about 10 seconds per batch for a simple ResNet-18 on CIFAR-10 with batch size 512. Training the same batch size on the CPU also takes about 10 seconds, which makes no sense. Worse still, without the manual optimizations shown below, a GPU batch takes about 20 seconds. I have never installed older CUDA versions on this machine.

Environment details:

Torch version         : 2.9.0.dev20250704+cu128
CUDA available        : True
Torch CUDA version    : 12.8
CUDA device count     : 1
Device 0 name        : NVIDIA GeForce RTX 5070 Laptop GPU
Compute capability   : sm_120
Supported arch list  : ['sm_61', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90', 'sm_100', 'sm_120']
Total memory (MB)    : 8150
Basic GPU compute    : OK

cuDNN & TF32 Settings:
cudnn.benchmark      : False
TF32 matmul allowed  : False
TF32 cudnn allowed   : True

I tested torch.compile() with both default and reduce-overhead modes. The default mode compiled and ran successfully, but the reduce-overhead mode failed with an overflow error:

Testing with compile mode='reduce-overhead':
reduce-overhead mode failed: Python int too large to convert to C long
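
For reference, the mode check was roughly along these lines (a simplified sketch, not my full benchmark; the dummy input shape is just a placeholder):

import torch
from torchvision import models

device = torch.device("cuda")
x = torch.randn(8, 3, 224, 224, device=device)  # dummy input, placeholder shape

for mode in ("default", "reduce-overhead"):
    model = models.resnet18(weights=None, num_classes=10).to(device)
    compiled = torch.compile(model, mode=mode)
    print(f"Testing with compile mode='{mode}':")
    try:
        with torch.no_grad():
            compiled(x)              # first call triggers compilation
        torch.cuda.synchronize()
        print(f"{mode} mode ran successfully")
    except Exception as e:
        print(f"{mode} mode failed: {e}")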

Code snippet for training benchmark:

import time
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
import torch.backends.cudnn as cudnn
import torch.amp as amp

def main():
    # Select device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # —— Global performance flags ——  
    cudnn.benchmark = True                       # Enable cuDNN autotuner  
    torch.backends.cuda.matmul.allow_tf32 = True # Enable TF32 for matmul  
    torch.backends.cudnn.allow_tf32 = True       # Enable TF32 for cuDNN

    # —— Data loading setup ——  
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5, 0.5))
    ])
    train_dataset = datasets.CIFAR10(
        root='.', train=True, download=True, transform=transform
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=512,
        shuffle=True,
        num_workers=8,            # Number of data-loading workers  
        pin_memory=True,          # Use pinned memory for faster transfers  
        persistent_workers=True,  # Keep workers alive between epochs  
        prefetch_factor=2         # Number of batches to prefetch per worker  
    )

    # —— Model creation & JIT compile ——  
    model = models.resnet18(weights=None, num_classes=10).to(device)
    model = torch.compile(model, backend="eager")  # "eager" backend: skip Inductor codegen/auto-tuning to rule out compile overhead

    # —— Optimizer, loss, and mixed precision setup ——  
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss()
    scaler = amp.GradScaler()

    # —— Training loop (20 batches) ——  
    model.train()
    print("\n📘 Starting training for 20 batches")

    batch_times = []
    for batch_idx, (images, labels) in enumerate(train_loader):
        if batch_idx >= 20:
            break

        print(f"\n🟢 Batch {batch_idx}")
        t0 = time.time()

        # Move data to device
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        torch.cuda.synchronize()
        print(f"  🔸 Data transfer: {time.time() - t0:.3f}s")

        optimizer.zero_grad()

        # Forward pass + loss
        t1 = time.time()
        with amp.autocast(device_type='cuda'):
            outputs = model(images)
            loss = loss_fn(outputs, labels)
        torch.cuda.synchronize()
        print(f"  🔸 Forward + loss: {time.time() - t1:.3f}s")

        # Backward pass + optimizer step
        t2 = time.time()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        torch.cuda.synchronize()
        print(f"  🔸 Backward + step: {time.time() - t2:.3f}s")

        print(f"  ✅ Loss: {loss.item():.4f}")
        batch_times.append(time.time() - t0)

    avg_time = sum(batch_times) / len(batch_times) if batch_times else 0.0
    print(f"\n✅ Average batch time: {avg_time:.3f}s")

if __name__ == "__main__":
    main()

Result:

📘 Starting training for 20 batches

🟢 Batch 0
  🔸 Data transfer: 0.819s
  🔸 Forward + loss: 5.622s
  🔸 Backward + step: 7.883s
  ✅ Loss: 2.4147

🟢 Batch 1
  🔸 Data transfer: 0.022s
  🔸 Forward + loss: 0.246s
  🔸 Backward + step: 1.893s
  ✅ Loss: 2.3446

🟢 Batch 2
  🔸 Data transfer: 0.021s
  🔸 Forward + loss: 4.790s
  🔸 Backward + step: 6.619s
  ✅ Loss: 2.8196

🟢 Batch 3
  🔸 Data transfer: 0.022s
  🔸 Forward + loss: 4.910s
  🔸 Backward + step: 6.618s
  ✅ Loss: 2.4265

🟢 Batch 4
  🔸 Data transfer: 0.022s
  🔸 Forward + loss: 4.827s
  🔸 Backward + step: 6.633s
  ✅ Loss: 2.1797

🟢 Batch 5
  🔸 Data transfer: 0.022s
  🔸 Forward + loss: 4.835s
  🔸 Backward + step: 6.681s
  ✅ Loss: 2.0355

🟢 Batch 6
  🔸 Data transfer: 0.022s
  🔸 Forward + loss: 4.976s
  🔸 Backward + step: 6.618s
  ✅ Loss: 1.9456

🟢 Batch 7
  🔸 Data transfer: 0.022s
  🔸 Forward + loss: 4.819s
  🔸 Backward + step: 6.621s
  ✅ Loss: 1.8671

🟢 Batch 8
  🔸 Data transfer: 0.023s
  🔸 Forward + loss: 4.803s
  🔸 Backward + step: 6.612s
  ✅ Loss: 1.8940

🟢 Batch 9
  🔸 Data transfer: 0.022s
  🔸 Forward + loss: 4.798s
  🔸 Backward + step: 6.638s
  ✅ Loss: 1.8690

🟢 Batch 10
  🔸 Data transfer: 0.021s
  🔸 Forward + loss: 4.828s
  🔸 Backward + step: 6.638s
  ✅ Loss: 1.7905

🟢 Batch 11
  🔸 Data transfer: 0.022s
  🔸 Forward + loss: 4.859s
  🔸 Backward + step: 6.957s
  ✅ Loss: 1.8161

🟢 Batch 12
  🔸 Data transfer: 0.022s
  🔸 Forward + loss: 5.010s
  🔸 Backward + step: 6.836s
  ✅ Loss: 1.6921

🟢 Batch 13
  🔸 Data transfer: 0.021s
  🔸 Forward + loss: 4.990s
  🔸 Backward + step: 6.869s
  ✅ Loss: 1.7362

🟢 Batch 14
  🔸 Data transfer: 0.022s
  🔸 Forward + loss: 4.996s
  🔸 Backward + step: 6.884s
  ✅ Loss: 1.7184

🟢 Batch 15
  🔸 Data transfer: 0.022s
  🔸 Forward + loss: 5.033s
  🔸 Backward + step: 6.869s
  ✅ Loss: 1.7337

... ...

✅ Average batch time: 11.345s

Questions:
Is there any known issue or additional configuration needed for RTX 5070 Laptop GPU (sm_120) to achieve expected performance?

Are my cuDNN and TF32 settings optimal? Should cudnn.benchmark and TF32 matmul be enabled?

Could PyTorch nightly 2.9.0 + CUDA 12.8 have missing or incomplete optimizations for this new GPU architecture?

Any debugging tips or profiling tools recommended to identify bottlenecks?
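
On the profiling question, this is the kind of torch.profiler run I’m planning to try next (a minimal sketch; train_step() is a hypothetical helper wrapping my forward/backward/optimizer code, and train_loader is the DataLoader from the benchmark above):

import torch
from torch.profiler import profile, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
) as prof:
    for step, (images, labels) in enumerate(train_loader):
        if step >= 5:                      # profile only a handful of batches
            break
        train_step(images, labels)         # hypothetical helper: forward + backward + optimizer step
        torch.cuda.synchronize()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))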

Thanks in advance for your help!

I don’t have a 5070 handy, but this is the performance I see on a 5090 using your code and the latest nightly binaries with CUDA 12.8:

📘 Starting training for 20 batches

🟢 Batch 0
  🔸 Data transfer: 0.014s
  🔸 Forward + loss: 1.920s
  🔸 Backward + step: 2.325s
  ✅ Loss: 2.3930

🟢 Batch 1
  🔸 Data transfer: 0.007s
  🔸 Forward + loss: 0.030s
  🔸 Backward + step: 0.056s
  ✅ Loss: 2.8687

🟢 Batch 2
  🔸 Data transfer: 0.007s
  🔸 Forward + loss: 0.030s
  🔸 Backward + step: 0.056s
  ✅ Loss: 2.3795

🟢 Batch 3
  🔸 Data transfer: 0.008s
  🔸 Forward + loss: 0.030s
  🔸 Backward + step: 0.056s
  ✅ Loss: 2.2824
...
🟢 Batch 18
  🔸 Data transfer: 0.013s
  🔸 Forward + loss: 0.030s
  🔸 Backward + step: 0.057s
  ✅ Loss: 1.6297

🟢 Batch 19
  🔸 Data transfer: 0.010s
  🔸 Forward + loss: 0.030s
  🔸 Backward + step: 0.057s
  ✅ Loss: 1.7118

✅ Average batch time: 0.305s

I have an RTX 5070; this was my run.
(pygpu) oba@mail:~/code/llamaserv$ python testrun.py
Using device: cuda
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 170M/170M [00:20<00:00, 8.14MB/s]

📘 Starting training for 20 batches

🟢 Batch 0
  🔸 Data transfer: 0.013s
  🔸 Forward + loss: 2.585s
  🔸 Backward + step: 1.645s
  ✅ Loss: 2.4328

🟢 Batch 1
  🔸 Data transfer: 0.013s
  🔸 Forward + loss: 0.093s
  🔸 Backward + step: 0.167s
  ✅ Loss: 2.9187

🟢 Batch 2
  🔸 Data transfer: 0.013s
  🔸 Forward + loss: 0.092s
  🔸 Backward + step: 0.167s
  ✅ Loss: 2.8075

🟢 Batch 3
  🔸 Data transfer: 0.013s
  🔸 Forward + loss: 0.093s
  🔸 Backward + step: 0.168s
  ✅ Loss: 2.3828

🟢 Batch 4
  🔸 Data transfer: 0.013s
  🔸 Forward + loss: 0.093s
  🔸 Backward + step: 0.168s
  ✅ Loss: 2.1121

🟢 Batch 5
  🔸 Data transfer: 0.013s
  🔸 Forward + loss: 0.091s
  🔸 Backward + step: 0.168s
  ✅ Loss: 2.0617

🟢 Batch 6
  🔸 Data transfer: 0.013s
  🔸 Forward + loss: 0.093s
  🔸 Backward + step: 0.167s
  ✅ Loss: 2.0539

🟢 Batch 7
  🔸 Data transfer: 0.013s
  🔸 Forward + loss: 0.092s
  🔸 Backward + step: 0.168s
  ✅ Loss: 2.0069

🟢 Batch 8
  🔸 Data transfer: 0.013s
  🔸 Forward + loss: 0.091s
  🔸 Backward + step: 0.167s
  ✅ Loss: 1.9361

🟢 Batch 9
  🔸 Data transfer: 0.013s
  🔸 Forward + loss: 0.093s
  🔸 Backward + step: 0.168s
  ✅ Loss: 1.8847

🟢 Batch 10
  🔸 Data transfer: 0.013s
  🔸 Forward + loss: 0.092s
  🔸 Backward + step: 0.167s
  ✅ Loss: 1.8330

🟢 Batch 11
  🔸 Data transfer: 0.013s
  🔸 Forward + loss: 0.091s
  🔸 Backward + step: 0.168s
  ✅ Loss: 1.7778

🟢 Batch 12
  🔸 Data transfer: 0.013s
  🔸 Forward + loss: 0.093s
  🔸 Backward + step: 0.168s
  ✅ Loss: 1.8241

🟢 Batch 13
  🔸 Data transfer: 0.013s
  🔸 Forward + loss: 0.091s
  🔸 Backward + step: 0.168s
  ✅ Loss: 1.7903

🟢 Batch 14
  🔸 Data transfer: 0.013s
  🔸 Forward + loss: 0.093s
  🔸 Backward + step: 0.167s
  ✅ Loss: 1.7902

🟢 Batch 15
  🔸 Data transfer: 0.013s
  🔸 Forward + loss: 0.092s
  🔸 Backward + step: 0.167s
  ✅ Loss: 1.7115

🟢 Batch 16
  🔸 Data transfer: 0.013s
  🔸 Forward + loss: 0.092s
  🔸 Backward + step: 0.168s
  ✅ Loss: 1.7007

🟢 Batch 17
  🔸 Data transfer: 0.013s
  🔸 Forward + loss: 0.093s
  🔸 Backward + step: 0.168s
  ✅ Loss: 1.6957

🟢 Batch 18
  🔸 Data transfer: 0.013s
  🔸 Forward + loss: 0.092s
  🔸 Backward + step: 0.168s
  ✅ Loss: 1.6608

🟢 Batch 19
  🔸 Data transfer: 0.013s
  🔸 Forward + loss: 0.091s
  🔸 Backward + step: 0.168s
  ✅ Loss: 1.6665

✅ Average batch time: 0.471s
(pygpu) oba@mail:~/code/llamaserv$

(pygpu) oba@mail:~/code/llamaserv$ nvidia-smi
Thu Jul 17 14:31:30 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.57.08              Driver Version: 575.57.08      CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 5070        On  |   00000000:01:00.0  On |                  N/A |
| 30%   43C    P1             20W /  250W |     680MiB /  12227MiB |      7%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

I’ve reinstalled the NVIDIA Studio driver (v576.80). Here’s the updated nvidia-smi output:

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 576.80                 Driver Version: 576.80         CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 5070 ...  WDDM  |   00000000:01:00.0 Off |                  N/A |
| N/A   52C    P4             13W /   40W |    7742MiB /   8151MiB |    100%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

However, my training iteration time is still about 3 seconds per batch, which is much slower than expected. I’m not sure what is causing this slowdown. Could it be related to Windows system settings or something else?

I’d appreciate any suggestions or tips on how to debug this. Thanks!

Hi everyone, thanks for all the suggestions so far!

My GPU has 8 GB of VRAM, so I tried running a clean version of the code (no AMP, no TF32 flags, no non-blocking transfers, no cudnn.benchmark, no torch.compile) to isolate the effect of batch size:

import time
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # ——— Data transforms & loader ———
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5))
    ])
    train_dataset = datasets.CIFAR10(root='.', train=True, download=True, transform=transform)
    train_loader  = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=400,           # ← first test with 256, then 400
        shuffle=True,
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=2
    )

    # ——— Model, optimizer, loss ———
    model    = models.resnet18(weights=None, num_classes=10).to(device)
    optimizer= optim.Adam(model.parameters(), lr=0.001)
    loss_fn  = nn.CrossEntropyLoss()

    model.train()
    print("\n📘 Starting training for 20 batches\n")

    batch_times = []
    for batch_idx, (images, labels) in enumerate(train_loader):
        if batch_idx >= 20:
            break

        t0 = time.time()
        images = images.to(device)   # only .to(device)
        labels = labels.to(device)

        torch.cuda.synchronize()     # sync for accurate timing

        optimizer.zero_grad()
        outputs = model(images)
        loss    = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        torch.cuda.synchronize()     # end sync

        batch_time = time.time() - t0
        batch_times.append(batch_time)
        print(f"Batch {batch_idx}: Loss={loss.item():.4f}, Time={batch_time:.3f}s")

    avg_time = sum(batch_times)/len(batch_times)
    print(f"\n✅ Average batch time: {avg_time:.3f}s")

if __name__ == "__main__":
    main()

Results

  • batch_size=256
Using device: cuda

📘 Starting training for 20 batches
Batch 0: Loss=2.3960, Time=2.272s
Batch 1: Loss=2.7107, Time=0.585s
Batch 2: Loss=2.5974, Time=0.557s
Batch 3: Loss=2.1467, Time=0.536s
Batch 4: Loss=2.1897, Time=0.520s
Batch 5: Loss=2.0713, Time=0.517s
Batch 6: Loss=1.9585, Time=0.512s
Batch 7: Loss=1.9605, Time=0.511s
Batch 8: Loss=2.0029, Time=0.511s
Batch 9: Loss=1.8113, Time=0.510s
Batch 10: Loss=1.7202, Time=0.511s
Batch 11: Loss=1.7713, Time=0.512s
Batch 12: Loss=1.7673, Time=0.510s
Batch 13: Loss=1.8245, Time=0.511s
Batch 14: Loss=1.7219, Time=0.511s
Batch 15: Loss=1.7513, Time=0.512s
Batch 16: Loss=1.7893, Time=0.511s
Batch 17: Loss=1.7394, Time=0.511s
Batch 18: Loss=1.7888, Time=0.511s
Batch 19: Loss=1.7004, Time=0.510s

✅ Average batch time: 0.607s
  • batch_size=400
Using device: cuda

📘 Starting training for 20 batches
Batch 0: Loss=2.4176, Time=14.423s
Batch 1: Loss=2.8280, Time=19.925s
Batch 2: Loss=2.9932, Time=19.657s
Batch 3: Loss=2.6546, Time=19.473s
Batch 4: Loss=2.2218, Time=19.839s
Batch 5: Loss=2.1040, Time=19.643s
Batch 6: Loss=2.0553, Time=19.474s
Batch 7: Loss=2.0001, Time=19.561s
Batch 8: Loss=1.9294, Time=19.484s
Batch 9: Loss=1.8814, Time=19.755s
Batch 10: Loss=1.8365, Time=19.595s
Batch 11: Loss=1.9479, Time=19.704s
Batch 12: Loss=1.8728, Time=19.697s
Batch 13: Loss=1.8680, Time=19.529s
Batch 14: Loss=1.7708, Time=19.609s
Batch 15: Loss=1.8914, Time=19.958s
Batch 16: Loss=1.7299, Time=19.593s
Batch 17: Loss=1.6805, Time=19.510s
Batch 18: Loss=1.7659, Time=19.452s
Batch 19: Loss=1.7543, Time=19.528s

✅ Average batch time: 19.370s

Even with all of my previous manual optimizations enabled (AMP, non-blocking copies, TF32, cudnn.benchmark, torch.compile, etc.), I always see ~3 s per batch when batch_size=512. As soon as I drop to 256, it snaps back to ~0.5–0.6 s.

Initial conclusion: for now, my working assumption is that on an 8 GB RTX 5070 Laptop GPU, batch_size=512 is simply too large to fit or to run efficiently, whether in plain FP32 or with AMP, causing huge stalls or out-of-memory behavior. Reducing batch_size to 256 (or lower) is an effective way to get sub-second batch times on this hardware.
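
To test this assumption, my next step is to log peak GPU memory around a few batches (just the standard torch.cuda memory stats, sketched below):

import torch

torch.cuda.reset_peak_memory_stats()

# ... run a few training batches here ...

peak_alloc    = torch.cuda.max_memory_allocated() / 1024**2
peak_reserved = torch.cuda.max_memory_reserved() / 1024**2
print(f"Peak allocated: {peak_alloc:.0f} MiB, peak reserved: {peak_reserved:.0f} MiB")
# If these numbers sit close to the ~8150 MiB total, the slowdown is probably memory pressure.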

If anyone has further suggestions on how to maximize training speed under tight VRAM constraints, I’d really appreciate your input!

I’ve already experimented with:

  • Automatic Mixed Precision (AMP)
  • TF32 mode
  • non_blocking=True on .to()
  • cudnn.benchmark = True
  • torch.compile()
  • Memory pinning and prefetching in the DataLoader

Are there other techniques—maybe less common—that could help reduce memory overhead or improve throughput without sacrificing too much accuracy?
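
For example, one less common option I haven’t tried yet is gradient accumulation: keeping the effective batch size at 512 while only ever holding a small micro-batch in memory. A rough sketch (the accum_steps value and the micro-batch size of 128 are just illustrative):

accum_steps = 4                            # 4 micro-batches of 128 = effective batch of 512

optimizer.zero_grad()
for batch_idx, (images, labels) in enumerate(train_loader):    # DataLoader built with batch_size=128
    images, labels = images.to(device), labels.to(device)
    loss = loss_fn(model(images), labels) / accum_steps         # scale so gradients average over the effective batch
    loss.backward()
    if (batch_idx + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()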

Thanks again for all the great insights so far!

Thanks for the additional tests! Indeed, you might be running out of memory on your GPU, but since you are using Windows, the OS will offload device data to the host behind your back, causing a large slowdown. You could disable this Windows feature and check if a proper OOM error is raised instead.
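
As a quick sanity check (a rough sketch, adjust the shape to your resized inputs), you could allocate one full batch and see whether a proper torch.cuda.OutOfMemoryError is raised once that fallback is disabled:

import torch
from torchvision import models

device = torch.device("cuda")
model = models.resnet18(weights=None, num_classes=10).to(device)
try:
    x = torch.randn(512, 3, 224, 224, device=device)   # one full batch of 224x224 inputs
    model(x).sum().backward()
    torch.cuda.synchronize()
    print("Ran without an explicit OOM")
except torch.cuda.OutOfMemoryError as e:
    print("Proper OOM raised:", e)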