Is pytorch1.10.0+cu113 slower than torch1.8.2+cu102?

I tried to run my training code with torch1.10.0+cu113, only to find that it runs slower than with torch1.8.2+cu102. Does anyone have the same problem?

from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import random


class BaseDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        assert len(images) == len(labels), 'the number of images must equal the number of labels'
        self.images = images
        self.labels = labels
        self.transform = transform
        
    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)
        return image, label
        
    def __len__(self):
        return len(self.images)


random.seed(2021)
np.random.seed(2021)
torch.manual_seed(2021)

# deterministic seeding for DataLoader worker processes
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    print(f"seed of process-{ worker_id }: { worker_seed }")
    random.seed(worker_seed)
    np.random.seed(worker_seed)


epoch_num = 1000
batch_size = 32


model = models.resnet50()
if torch.cuda.is_available():
    model.cuda()


optimizer = optim.Adam(model.parameters(), lr=1e-4)


if __name__ == '__main__':
    print("---start training...")

    # torch.backends.cudnn.benchmark = True  # use this for faster training
    torch.backends.cudnn.deterministic = True  # use this for more deterministic training

    for epoch in range(1, epoch_num + 1):
        # reset loss and item
        running_loss = 0.0
        ite_num4loss = 0
        
        images = np.random.randn(1600, 224, 224, 3)
        labels = np.random.randint(1000, size=(1600))
        
        dataset = BaseDataset(images, labels, 
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize(
                                    mean=[0.484, 0.454, 0.403], 
                                    std=[0.225, 0.220, 0.220])]))
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, 
                                worker_init_fn=seed_worker)  # set num_workers = 0 in jupyter


        model.train()

        for i, data in tqdm(enumerate(dataloader)):
            ite_num4loss = ite_num4loss + 1
    
            inputs, labels = data
            inputs = inputs.to(torch.float32)
            labels = labels.to(torch.int64)
            
            if torch.cuda.is_available():
                inputs, labels = inputs.cuda(), labels.cuda()
    
            # zero the parameter gradients
            optimizer.zero_grad()
    
            # forward + backward + optimize
            out = model(inputs)
            loss = F.cross_entropy(out, labels)
    
            loss.backward()
            optimizer.step()
    
            # print statistics
            running_loss += loss.item()
            
            print(f"[epoch: {epoch:3d}/{epoch_num:3d}" + "," +
                f" batch: {(i+1)*batch_size:6d}/{len(dataset)}" + "," +
                f" train loss: {running_loss / ite_num4loss:3f}")

I tested the above code with torch1.8.2+cu102 and torch1.10.0+cu113, respectively. The results are as follows:

  • torch1.8.2+cu102, 2.73 it/s, 19 s/epoch

  • torch1.10.0+cu113, 2.36 it/s, 25 s/epoch

You can easily run this code from the command line.
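
For reference, the it/s numbers above are simply what tqdm reports. A minimal sketch for timing one epoch more directly (the time_one_epoch helper is just for illustration; it reuses the model, optimizer, dataloader and imports from the script above, and the synchronize calls are needed because CUDA kernels are launched asynchronously):

import time

# illustrative helper: times one full epoch, including data loading
def time_one_epoch(model, optimizer, dataloader):
    model.train()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # finish pending kernels before starting the clock
    t0 = time.perf_counter()
    for inputs, labels in dataloader:
        inputs = inputs.to(torch.float32)
        labels = labels.to(torch.int64)
        if torch.cuda.is_available():
            inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), labels)
        loss.backward()
        optimizer.step()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for all GPU work before stopping the clock
    return time.perf_counter() - t0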

This is most likely not an issue with PyTorch (alone) but may be the result of other influences. Depending on the hardware in use and on how the packages were built and installed, it is not surprising to see some speed differences between environments. As an example, a user with an RTX 3090 will see a meaningful change in speed for specific operations when using cu111+ binaries (libtorch_cuda.so is missing fast kernels from libcudnn_static.a, therefore statically linked cuDNN could be much slower than dynamically linked · Issue #50153 · pytorch/pytorch · GitHub).

The linked issue points to a build failure in the pip wheels and conda binaries when cuDNN is statically linked, which should not be visible anymore. It is also not directly relevant to the library versions used here, as an RTX 3090 would need CUDA>=11, so you would hit a runtime issue when trying to run the 1.8.2+cu102 binaries.

@hideinshadow could you post the output of python -m torch.utils.collect_env?
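
As a quick sanity check next to collect_env, these standard PyTorch attributes also show which CUDA toolkit and cuDNN build each install actually loads:

import torch

print(torch.__version__)               # e.g. 1.8.2+cu102 vs. 1.10.0+cu113
print(torch.version.cuda)              # CUDA toolkit the wheel was built against
print(torch.backends.cudnn.version())  # cuDNN version as an integer, e.g. 8300 for 8.3.0
print(torch.cuda.get_device_name(0))   # GPU in use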

The outputs are as follows:


PyTorch version: 1.8.2+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Home
GCC version: Could not collect
Clang version: Could not collect
CMake version: version 3.21.2

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2060
Nvidia driver version: 496.49
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.20.3
[pip3] pytorch-ignite==0.4.7
[pip3] pytorch-nlp==0.5.0
[pip3] pytorch-ranger==0.1.1
[pip3] pytorchtools==0.0.2
[pip3] torch==1.8.2+cu102
[pip3] torch-lr-finder==0.2.1
[pip3] torch-optimizer==0.3.0
[pip3] torchaudio==0.8.2
[pip3] torchfile==0.1.0
[pip3] torchtext==0.9.2
[pip3] torchvision==0.9.2+cu102
[conda] numpy 1.20.3 pypi_0 pypi
[conda] pytorch-ignite 0.4.7 pypi_0 pypi
[conda] pytorch-nlp 0.5.0 pypi_0 pypi
[conda] pytorch-ranger 0.1.1 pypi_0 pypi
[conda] pytorchtools 0.0.2 pypi_0 pypi
[conda] torch 1.8.2+cu102 pypi_0 pypi
[conda] torch-lr-finder 0.2.1 pypi_0 pypi
[conda] torch-optimizer 0.3.0 pypi_0 pypi
[conda] torchaudio 0.8.2 pypi_0 pypi
[conda] torchfile 0.1.0 pypi_0 pypi
[conda] torchtext 0.9.2 pypi_0 pypi
[conda] torchvision 0.9.2+cu102 pypi_0 pypi


PyTorch version: 1.10.0+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Home
GCC version: Could not collect
Clang version: Could not collect
CMake version: version 3.21.2
Libc version: N/A

Python version: 3.8.12 (default, Oct 12 2021, 03:01:40) [MSC v.1916 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19041-SP0
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2060
Nvidia driver version: 496.49
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.20.3
[pip3] torch==1.10.0+cu113
[pip3] torchtext==0.11.0
[pip3] torchvision==0.11.1+cu113
[conda] numpy 1.20.3 pypi_0 pypi
[conda] torch 1.10.0+cu113 pypi_0 pypi
[conda] torchtext 0.11.0 pypi_0 pypi
[conda] torchvision 0.11.1+cu113 pypi_0 pypi

Thanks for sharing this output!

I was able to reproduce a minor regression in the 1.10.0+cu113 wheels, which already seems to be resolved with the latest cuDNN release (8.3.0.96).
Code used to profile the workload:

import random
import time

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models

random.seed(2021)
np.random.seed(2021)
torch.manual_seed(2021)

print(torch.cuda.get_device_name(0))
torch.backends.cudnn.benchmark = False

model = models.resnet50()
if torch.cuda.is_available():
    model.cuda()

optimizer = optim.Adam(model.parameters(), lr=1e-4)

if __name__ == '__main__':
    print("---start training...")

    # torch.backends.cudnn.benchmark = True   # enabled for the "benchmark = True" runs below
    torch.backends.cudnn.deterministic = False  # toggled between True and False for the runs below
        
    inputs = torch.randn(32, 3, 224, 224).cuda()
    labels = torch.randint(1000, size=(32,)).cuda()

    # warmup
    for i in range(10): 
        # zero the parameter gradients
        optimizer.zero_grad()
    
        # forward + backward + optimize
        out = model(inputs)
        loss = F.cross_entropy(out, labels)
    
        loss.backward()
        optimizer.step()
    
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for i in range(100): 
        # zero the parameter gradients
        optimizer.zero_grad()
    
        # forward + backward + optimize
        out = model(inputs)
        loss = F.cross_entropy(out, labels)
    
        loss.backward()
        optimizer.step()
    torch.cuda.synchronize()
    t1 = time.perf_counter()

    print((t1 - t0)/100)

Output (average seconds per iteration):

1.8.2+cu102
deterministic = True
0.13085028601810336

deterministic = False
0.12192453876137734

benchmark = True
0.11950120938941837


1.10.0+cu113
deterministic = True
0.16201625971123576

deterministic = False
0.15164822606369854

benchmark = True
0.13653683779761194


source build + cuDNN 8.3.0.96
deterministic = True
0.14494292575865983

deterministic = False
0.11826244482770562

benchmark = True
0.11710769269615412
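
For a rough comparison (just dividing the timings above), the cu113 wheel comes out roughly 24% slower with the default heuristics and roughly 14% slower in benchmark mode, while the source build with cuDNN 8.3.0.96 is back on par with 1.8.2+cu102:

# ratios of the per-iteration times reported above
print(0.16201625971123576 / 0.13085028601810336)  # 1.10.0+cu113 vs 1.8.2+cu102, deterministic=True  -> ~1.24
print(0.15164822606369854 / 0.12192453876137734)  # 1.10.0+cu113 vs 1.8.2+cu102, deterministic=False -> ~1.24
print(0.13653683779761194 / 0.11950120938941837)  # 1.10.0+cu113 vs 1.8.2+cu102, benchmark=True      -> ~1.14
print(0.11826244482770562 / 0.12192453876137734)  # source build + cuDNN 8.3.0.96 vs 1.8.2+cu102, deterministic=False -> ~0.97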

You could either build from source using the current cuDNN release or wait until the wheels are updated.
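
If rebuilding is not an option right now, the numbers above also suggest that enabling the cuDNN autotuner (the line commented out in your original script) recovers part of the difference in the meantime; a minimal sketch:

import torch

# Let cuDNN benchmark candidate convolution algorithms on the first iterations and
# reuse the fastest one; this helps when the input shapes stay constant across
# batches, at the cost of run-to-run determinism.
torch.backends.cudnn.benchmark = True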


Thanks for your reply!