Significantly Different Results Using cuDNN

Hi,

I noticed a significant difference in the final results of AlexNet depending on whether I enable cuDNN or not, and whether I set the batch size to 1 or 100. When cuDNN is enabled and the batch size is set to 100, the difference between the GPU and CPU results is larger than 10%. In addition, the difference between the GPU results with a batch size of 1 and a batch size of 100 is also greater than 10%. Some outliers even show a much larger difference. I’ve tested this with different PyTorch versions and on various NVIDIA GPUs (including Google Colab), so it doesn’t seem to depend on the cluster hardware.

Am I missing something? If I set my model to evaluation mode, shouldn’t the results be more or less the same? (I understand that some differences can occur due to floating-point precision and the GPU vs. CPU, but are the differences supposed to be this large?)

Thank you very much!

Here is the code; you can also reproduce it directly on Google Colab.
This is the code for the GPU, with cuDNN enabled and float32 as the dtype.

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.models import get_weight, alexnet
import numpy as np

print(torch.__version__)
seed = 42
num_workers = 0
np.random.seed(seed)
torch.random.manual_seed(seed)
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = True

FLOAT_64 = False
DEVICE = "cuda"
pretrained = True

# Set datatype to float32
if FLOAT_64:
    torch.set_default_dtype(torch.float64)
else:
    torch.set_default_dtype(torch.float32)
    torch.set_float32_matmul_precision("high")

VERSION = "IMAGENET1K_V1"
version_id = f"AlexNet_Weights.{VERSION}"

# Load the pretrained weights
weights = get_weight(version_id)

# Load model and preprocessing transformations
if pretrained:
    model = alexnet(weights=weights, progress=True)
else:
    model = alexnet(weights=None)

transforms = weights.transforms()

class DummyDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(200, 3, 227, 227)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        tmp_data = self.data[index]
        tmp_data = transforms(tmp_data)
        if FLOAT_64:
            tmp_data = tmp_data.to(torch.float64)
        return tmp_data

device = torch.device(DEVICE)

print(device)

model.to(device)
model.eval()

dataset = DummyDataset()
dataloader_batch16 = DataLoader(dataset, batch_size=100, shuffle=False, num_workers=num_workers)
for batch_idx, data in enumerate(dataloader_batch16):
  data = data.to(device)
  with torch.no_grad(): 
      data_o = model(data)
      data_o = data_o.cpu()
  break

dataloader_batch1 = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=num_workers)
store_results = torch.zeros((100, 1000), device = "cpu")

for batch_idx, data in enumerate(dataloader_batch1):          
  data = data.to(device)
  with torch.no_grad(): 
      store_results[batch_idx, :] = model(data).cpu()

  if batch_idx == 99:
      break

torch.allclose(data_o, store_results, rtol=0.1)

Gives ‘False’.

The following is the code for the CPU (run after the first code snippet), also with float32:

# Run the same on CPU

DEVICE = "cpu"

# Set datatype to float32
if FLOAT_64:
    torch.set_default_dtype(torch.float64)
else:
    torch.set_default_dtype(torch.float32)
    torch.set_float32_matmul_precision("high")

VERSION = "IMAGENET1K_V1"
version_id = f"AlexNet_Weights.{VERSION}"

# Load the pretrained weights
weights = get_weight(version_id)

# Load model and preprocessing transformations
if pretrained:
    model = alexnet(weights=weights, progress=True)
else:
    model = alexnet(weights=None)

transforms = weights.transforms()

device = torch.device(DEVICE)

print(device)

model.to(device)
model.eval()

for batch_idx, data in enumerate(dataloader_batch16):
  data = data.to(device)
  with torch.no_grad(): 
      data_o_cpu = model(data)
  break

store_results_cpu = torch.zeros((100, 1000), device = "cpu")

for batch_idx, data in enumerate(dataloader_batch1):          
  data = data.to(device)
  with torch.no_grad(): 
      store_results_cpu[batch_idx, :] = model(data).cpu()

  if batch_idx == 99:
      break

torch.allclose(data_o_cpu, store_results_cpu, rtol=0.1)

Gives ‘True’.

# Compare results between GPU and CPU
# Batch size 100
torch.allclose(data_o, data_o_cpu, rtol=0.1)
# Batch size 1
torch.allclose(store_results, store_results_cpu, rtol=0.1)

Gives ‘False’ for batch size 100 and ‘True’ for batch size 1.

As said, these differences do not occur if cuDNN is not enabled OR if float64 is used OR on the CPU.
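One thing that may be worth checking is how the comparison itself behaves. torch.allclose is documented to test |input - other| <= atol + rtol * |other| element-wise, with atol defaulting to 1e-8, so an element whose reference value is close to zero can fail even when the absolute gap is tiny. The sketch below mirrors that formula in plain Python. The function name `allclose_like` and the numbers are made up for illustration and are not taken from the AlexNet run.

```python
def allclose_like(a, b, rtol, atol=1e-8):
    """Element-wise test |a - b| <= atol + rtol * |b| (the documented torch.allclose rule)."""
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


# Hypothetical logits: only the first entry differs, by about 1e-4,
# and its reference value is close to zero.
reference = [1e-5, 2.0, -3.0]
candidate = [1.1e-4, 2.0, -3.0]

# Pure relative tolerance: the near-zero entry decides the outcome.
strict_10_percent = allclose_like(candidate, reference, rtol=0.1)
strict_600_percent = allclose_like(candidate, reference, rtol=6.0)

# Adding an absolute floor makes the check insensitive to tiny gaps near zero.
with_abs_floor = allclose_like(candidate, reference, rtol=0.1, atol=1e-3)

# The largest absolute gap is often a more informative number to report.
max_abs_gap = max(abs(x - y) for x, y in zip(candidate, reference))

print(strict_10_percent, strict_600_percent, with_abs_floor, max_abs_gap)
```

If the real outputs contain logits near zero, reporting the maximum absolute difference (or passing an explicit atol) alongside rtol may give a clearer picture of how far apart the results actually are.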

Hi everyone, I just wanted to follow up and see if anyone might have any ideas regarding my question. I would really appreciate any feedback. Thanks in advance!

This is not that surprising.

The differences you’re seeing are likely because cuDNN picks different algorithms to compute convs and matmuls, and these may accumulate partial results in a different order. Floating-point addition is not associative, so a different accumulation order gives a slightly different result.
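The associativity point can be seen with three ordinary Python floats (float64 here; float32 shows the same effect at a larger magnitude). This is only a minimal illustration of the arithmetic, not of anything cuDNN does internally.

```python
a, b, c = 0.1, 0.2, 0.3

left_to_right = (a + b) + c   # add a and b first, then c
right_to_left = a + (b + c)   # add b and c first, then a

# The two groupings are mathematically identical but round differently.
same = left_to_right == right_to_left
gap = abs(left_to_right - right_to_left)

print(same, gap)
```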

If reproducibility is a must, try this:

Set torch.backends.cudnn.deterministic = True and torch.backends.cudnn.benchmark = False. It forces PyTorch to use deterministic (but maybe slower) operations.

(I understand that some differences can occur due to floating-point precision and the GPU vs. CPU, but are the differences supposed to be this large?)

Yes, the differences can be pretty large for a deeper neural network, because discrepancies introduced at the beginning of the network snowball all the way to the end.
They are also more pronounced when the algorithms themselves differ, such as Toeplitz-based (matrix-multiplication) convs on the CPU or via cuBLAS versus Winograd-based convs in cuDNN for certain kernel sizes. Winograd-based convs are less precise as a property of the algorithm itself, not of the hardware.
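To make the algorithmic difference concrete, here is a small sketch of the textbook 1D Winograd F(2,3) form next to a direct 3-tap convolution. It is an assumption-laden toy: cuDNN's actual kernels are 2D, use their own tile sizes, and are not reproduced here. The function names are hypothetical.

```python
import random


def direct_conv(d, g):
    """Two outputs of a 3-tap filter g slid over 4 inputs d, computed tap by tap."""
    return [
        d[0] * g[0] + d[1] * g[1] + d[2] * g[2],
        d[1] * g[0] + d[2] * g[1] + d[3] * g[2],
    ]


def winograd_f23(d, g):
    """Same two outputs via Winograd F(2,3): 4 data multiplies instead of 6."""
    m1 = (d[0] - d[2]) * g[0]
    m2 = (d[1] + d[2]) * (g[0] + g[1] + g[2]) / 2
    m3 = (d[2] - d[1]) * (g[0] - g[1] + g[2]) / 2
    m4 = (d[1] - d[3]) * g[2]
    return [m1 + m2 + m3, m2 - m3 - m4]


random.seed(0)
d = [random.uniform(-1, 1) for _ in range(4)]
g = [random.uniform(-1, 1) for _ in range(3)]

ref = direct_conv(d, g)
win = winograd_f23(d, g)

# Both compute the same mathematical result; the extra additions and
# subtractions in the Winograd form change where rounding happens.
gaps = [abs(r - w) for r, w in zip(ref, win)]
print(ref, win, gaps)
```

The two functions agree up to rounding, but the Winograd form trades multiplications for extra additions and subtractions, which is where its additional rounding error comes from.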

Thank you very much for your reply! I greatly appreciate it!

Set torch.backends.cudnn.deterministic = True and torch.backends.cudnn.benchmark = False. It forces PyTorch to use deterministic (but maybe slower) operations.

I have already done this in the code example (with the result described above). It doesn’t seem to have any effect. I have highlighted this again in the Google Colab notebook. As mentioned, the batch size appears to have a significant impact on the final result. With a batch size of 100 and cuDNN enabled (even with deterministic algorithms and benchmarking disabled), the relative difference of some elements is more than 600% (rtol = 6 evaluates to ‘False’). This only occurs with cuDNN enabled and a batch size of 100. Note that in my experiment (not shown/described here), this had a highly significant impact on the final task results.

So, just to confirm:

  1. This large difference in this case – i.e., the combination of cuDNN and a batch size of 100 – is not due to a bug?
  2. Why does it not occur with a batch size of 1 or on the CPU? Is it because different algorithms are used?
  3. How can one handle reproducibility in this context if setting a seed, using deterministic algorithms, and disabling benchmarking doesn’t help?

Thanks again!

Okay, I’ve looked at your code closely. Even with deterministic algorithms, computing a whole batch at once and computing each sample individually and then concatenating can give different results through cuDNN, and that is not an inconsistency.

For example, here is a basic Conv2D written in two different ways that could each be parallelized:

import torch

batch_size = 2
in_channels = 1
out_channels = 2
height, width = 6, 6
kernel_size = 3

input_tensor = torch.randn(batch_size, in_channels, height, width)
kernel = torch.randn(out_channels, in_channels, kernel_size, kernel_size)

out_height = height - kernel_size + 1
out_width = width - kernel_size + 1

output_tensor = torch.zeros((batch_size, out_channels, out_height, out_width))

# Parallelized over batch dimension
for b in range(batch_size):
    for oc in range(out_channels):
        for i in range(out_height):
            for j in range(out_width):
                patch = input_tensor[b, :, i:i+kernel_size, j:j+kernel_size]
                output_tensor[b, oc, i, j] = torch.sum(patch * kernel[oc])

print("Output (Parallelized over batch dimension):")
print(output_tensor)

# Parallelized over out_height dimension
output_tensor_height_parallel = torch.zeros((batch_size, out_channels, out_height, out_width))

for b in range(batch_size):
    for oc in range(out_channels):
        partial_sum_buffer = torch.zeros(out_height, out_width)
        for i in range(out_height):
            for j in range(out_width):
                patch = input_tensor[b, :, i:i+kernel_size, j:j+kernel_size]
                partial_sum_buffer[i, j] = torch.sum(patch * kernel[oc])
        output_tensor_height_parallel[b, oc] = partial_sum_buffer

print("Output (Parallelized over out_height dimension):")
print(output_tensor_height_parallel)

Notice that the first and second implementations are exactly the same algorithm. On the GPU, however, each outer loop iteration runs on a separate core and the partial results are then reconciled, so the outputs differ slightly because the floating-point additions are done in a different order.

When you run the batched case, it’s possible that cuDNN uses the second algorithm, and when you run the single-sample case (and then build the concatenated tensor) it uses the first.
They are equivalent in logic, but differ in how floating-point error accumulates.
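The effect of splitting one reduction across workers can be sketched in plain Python. Here `chunked_sum` stands in for "each core sums its own slice, then the partial sums are combined". Both helper names and the choice of 32 chunks are assumptions for illustration only and say nothing about how cuDNN actually partitions work.

```python
import random


def sequential_sum(values):
    """Add the values one after another, as a single worker would."""
    total = 0.0
    for v in values:
        total += v
    return total


def chunked_sum(values, num_chunks):
    """Sum each chunk separately, then combine the partial sums."""
    size = -(-len(values) // num_chunks)  # ceiling division
    partials = [
        sequential_sum(values[i:i + size])
        for i in range(0, len(values), size)
    ]
    return sequential_sum(partials)


random.seed(0)
# Stand-in for the element-wise products that a convolution adds up.
products = [random.uniform(-1.0, 1.0) for _ in range(10_000)]

one_worker = sequential_sum(products)
many_workers = chunked_sum(products, num_chunks=32)

# Same numbers, same math, different grouping of the additions.
print(one_worker, many_workers, abs(one_worker - many_workers))
```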

And to clarify determinism: PyTorch (and cuDNN) only guarantee determinism for operations of the same shape (across GPU types). Once you change the batch size and manually split and concatenate, cuDNN doesn’t guarantee that it will use the same algorithm at batch size 1 as at batch size 100.
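Following the same-shape point, one possible workaround is to always call the model with one fixed batch size and pad the last, smaller chunk. The sketch below assumes that approach. `forward_fixed_batch` is a hypothetical helper, not a PyTorch API, and the tiny conv model is only a stand-in so the snippet runs on the CPU. It keeps the input shape constant and nothing more: it does not make the outputs match those obtained with a different batch size.

```python
import torch
from torch import nn


def forward_fixed_batch(model, samples, batch_size):
    """Run `model` on `samples` in chunks that always have exactly `batch_size` rows."""
    outputs = []
    with torch.no_grad():
        for start in range(0, samples.shape[0], batch_size):
            chunk = samples[start:start + batch_size]
            n_real = chunk.shape[0]
            if n_real < batch_size:
                # Pad by repeating the last sample so the input shape never changes.
                pad = chunk[-1:].expand(batch_size - n_real, *chunk.shape[1:])
                chunk = torch.cat([chunk, pad], dim=0)
            # Keep only the rows that belong to real samples.
            outputs.append(model(chunk)[:n_real])
    return torch.cat(outputs, dim=0)


torch.manual_seed(0)
demo_model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.Flatten()).eval()
demo_input = torch.randn(5, 3, 8, 8)

# Five samples with batch_size=2: chunks of 2, 2, and 1 (padded to 2).
demo_out = forward_fixed_batch(demo_model, demo_input, batch_size=2)
print(demo_out.shape)
```

With this pattern, every forward pass presents the same input shape to the backend, which is the condition under which the determinism guarantee above is stated to apply.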