Hi,
I noticed a significant difference in the final results of AlexNet depending on whether I enable cuDNN or not, and whether I set the batch size to 1 or 100. When cuDNN is enabled and the batch size is set to 100, the difference between the GPU and CPU results is larger than 10%. In addition, the difference between the GPU results with a batch size of 1 and a batch size of 100 is also greater than 10%. Some outliers even show a much larger difference. I’ve tested this with different PyTorch versions and on various NVIDIA GPUs (including Google Colab), so it doesn’t seem to depend on the cluster hardware.
Am I missing something? If I set my model to evaluation mode, shouldn’t the results be more or less the same? (I understand that some differences can occur due to floating-point precision and the GPU vs. CPU, but are the differences supposed to be this large?)
Thank you very much!
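As a toy illustration of the floating-point effect I mean (not the actual repro, just a sketch): merely reordering a float32 accumulation can already change the result in the last digits.

import torch
torch.manual_seed(0)
a = torch.randn(100_000, dtype=torch.float32)
print(a.sum().item())          # one accumulation order
print(a.flip(0).sum().item())  # reversed order; often differs in the last few digits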
Here is the code; you can also reproduce it directly on Google Colab.
This is the GPU version, with cuDNN enabled and float32 as the dtype.
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.models import get_weight, alexnet
import numpy as np
print(torch.__version__)
seed = 42
num_workers = 0
np.random.seed(seed)
torch.random.manual_seed(seed)
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = True
FLOAT_64 = False
DEVICE = "cuda"
pretrained = True
# Set datatype to float32
if FLOAT_64:
    torch.set_default_dtype(torch.float64)
else:
    torch.set_default_dtype(torch.float32)
    torch.set_float32_matmul_precision("high")
VERSION = "IMAGENET1K_V1"
version_id = f"AlexNet_Weights.{VERSION}"
# Load the pretrained weights
weights = get_weight(version_id)
# Load model and preprocessing transformations
if pretrained:
    model = alexnet(weights=weights, progress=True)
else:
    model = alexnet(weights=None)
transforms = weights.transforms()
class DummyDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(200, 3, 227, 227)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        tmp_data = self.data[index]
        tmp_data = transforms(tmp_data)
        if FLOAT_64:
            tmp_data = tmp_data.to(torch.float64)
        return tmp_data
device = torch.device(DEVICE)
print(device)
model.to(device)
model.eval()
dataset = DummyDataset()
# One forward pass over a single batch of 100 images
dataloader_batch100 = DataLoader(dataset, batch_size=100, shuffle=False, num_workers=num_workers)
for batch_idx, data in enumerate(dataloader_batch100):
    data = data.to(device)
    with torch.no_grad():
        data_o = model(data)
    data_o = data_o.cpu()
    break
# Forward passes over the same 100 images, one at a time
dataloader_batch1 = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=num_workers)
store_results = torch.zeros((100, 1000), device="cpu")
for batch_idx, data in enumerate(dataloader_batch1):
    data = data.to(device)
    with torch.no_grad():
        store_results[batch_idx, :] = model(data).cpu()
    if batch_idx == 99:
        break
torch.allclose(data_o, store_results, rtol=0.1)
Gives ‘False’.
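To see how large the mismatch actually is, here is a small diagnostic sketch using the tensors defined above (abs_diff and rel_diff are just helper names I introduce here):

# Quantify the disagreement between the batch-100 and batch-1 GPU outputs
abs_diff = (data_o - store_results).abs()
rel_diff = abs_diff / store_results.abs().clamp_min(1e-12)  # guard against division by zero
print("max abs diff:", abs_diff.max().item())
print("max rel diff:", rel_diff.max().item())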
Below is the CPU code (run after the first snippet), also with float32:
# Run the same on CPU
DEVICE = "cpu"
# Set datatype to float32
if FLOAT_64:
    torch.set_default_dtype(torch.float64)
else:
    torch.set_default_dtype(torch.float32)
    torch.set_float32_matmul_precision("high")
VERSION = "IMAGENET1K_V1"
version_id = f"AlexNet_Weights.{VERSION}"
# Load the pretrained weights
weights = get_weight(version_id)
# Load model and preprocessing transformations
if pretrained:
    model = alexnet(weights=weights, progress=True)
else:
    model = alexnet(weights=None)
transforms = weights.transforms()
device = torch.device(DEVICE)
print(device)
model.to(device)
model.eval()
# One forward pass over the batch of 100 images, now on the CPU
for batch_idx, data in enumerate(dataloader_batch100):
    data = data.to(device)
    with torch.no_grad():
        data_o_cpu = model(data)
    break

# Forward passes over the same 100 images, one at a time
store_results_cpu = torch.zeros((100, 1000), device="cpu")
for batch_idx, data in enumerate(dataloader_batch1):
    data = data.to(device)
    with torch.no_grad():
        store_results_cpu[batch_idx, :] = model(data).cpu()
    if batch_idx == 99:
        break
torch.allclose(data_o_cpu, store_results_cpu, rtol=0.1)
Gives ‘True’.
# Compare results between GPU and CPU
# Batch size 100
torch.allclose(data_o, data_o_cpu, rtol=0.1)
# Batch size 1
torch.allclose(store_results, store_results_cpu, rtol=0.1)
Gives ‘False’ for batch size 100 and ‘True’ for batch size 1.
As mentioned, these differences do not occur if cuDNN is disabled, if float64 is used, or when running purely on the CPU.
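For completeness, these are the toggles I mean (a minimal sketch; set them before building the model and rerunning the GPU snippet):

# Option 1: disable cuDNN so convolutions fall back to PyTorch's native CUDA kernels
torch.backends.cudnn.enabled = False
# Option 2: run everything in float64 instead (see the FLOAT_64 flag above)
FLOAT_64 = True

With either setting, the batch-1 and batch-100 results agree within rtol=0.1.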