DataLoader worker crashes if using pinned CPU data

Consider the following minimal, stripped-down code:

import torch.utils.data  # Version 1.12.1
import torchvision  # Version 0.13.1
import torchvision.transforms as transforms
import wandb  # Version 0.13.10

wandb.init(mode='disabled')
tfrm = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.1307,), std=(0.3081,))])
train_dataset = torchvision.datasets.MNIST(root='.', train=True, transform=tfrm, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, num_workers=4, shuffle=True, pin_memory=True, drop_last=True)

for epoch in range(10):
	print(f"EPOCH {epoch}")
	for i, (data_cpu, target_cpu) in enumerate(train_loader):
		if i < 10:
			wandb_image = wandb.Image(data_cpu)  # Wrap the (pinned) CPU batch for logging
		data = data_cpu.to('cuda', non_blocking=True)  # Async host-to-device copy from pinned memory
		target = target_cpu.to('cuda', non_blocking=True)
		calc = (data * target[:, None, None, None]).cpu()  # Just any computation on GPU

This crashes the DataLoader after a random number of epochs (at most 5 epochs in my testing so far) with:

Traceback (most recent call last):
  File "/home/allgeuer/anaconda3/envs/relish/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1163, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/home/allgeuer/anaconda3/envs/relish/lib/python3.9/queue.py", line 180, in get
    self.not_empty.wait(remaining)
  File "/home/allgeuer/anaconda3/envs/relish/lib/python3.9/threading.py", line 316, in wait
    gotit = waiter.acquire(True, timeout)
  File "/home/allgeuer/anaconda3/envs/relish/lib/python3.9/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 2387021) is killed by signal: Aborted. 

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/allgeuer/Code/ReLish/benchmark/error/dataloader_error.py", line 16, in <module>
    for i, (data_cpu, target_cpu) in enumerate(train_loader):
  File "/home/allgeuer/anaconda3/envs/relish/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 681, in __next__
    data = self._next_data()
  File "/home/allgeuer/anaconda3/envs/relish/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1359, in _next_data
    idx, data = self._get_data()
  File "/home/allgeuer/anaconda3/envs/relish/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1315, in _get_data
    success, data = self._try_get_data()
  File "/home/allgeuer/anaconda3/envs/relish/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1176, in _try_get_data
    raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
RuntimeError: DataLoader worker (pid(s) 2387021) exited unexpectedly

So there must be some kind of race condition and/or bug. Or have I performed an ‘illegal’ operation by accessing a pinned tensor for CPU computations?
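
If that is what is going wrong, one mitigation I can think of (just a sketch, not verified against this exact setup) would be to give wandb its own copy of the batch instead of a reference to the DataLoader’s pinned tensor, e.g. replacing the inner loop above with:

	for i, (data_cpu, target_cpu) in enumerate(train_loader):
		if i < 10:
			# .clone() copies the batch into fresh storage, so the wandb.Image no longer
			# holds a reference to the DataLoader's pinned staging buffer
			wandb_image = wandb.Image(data_cpu.clone())
		data = data_cpu.to('cuda', non_blocking=True)
		target = target_cpu.to('cuda', non_blocking=True)
		calc = (data * target[:, None, None, None]).cpu()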

The error stops happening if (any of):

  • I comment out the wandb image line
  • I set num_workers=0 instead of 4
  • I set pin_memory=False
  • I use the CPU instead of CUDA device

The error continues to happen even if (any of):

  • I don’t disable wandb on init (disabling it is just a convenience for this repro)
  • I add del wandb_image immediately after the wandb image line
  • I move the if/wandb.Image block to the end of the loop body (after the CPU->GPU transfer)

What is NOT the problem:

  • I am on Ubuntu, not Windows
  • I am nowhere near running out of GPU memory
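
One way I can think of to get more information out of the dying worker (a sketch only; I have not verified that faulthandler actually catches this particular abort) is to enable Python’s faulthandler in every worker via worker_init_fn, reusing train_dataset from the snippet above:

	import faulthandler
	import torch.utils.data

	def worker_init(worker_id):
		# Dump a Python traceback to stderr if this worker receives SIGABRT/SIGSEGV etc.
		faulthandler.enable()

	train_loader = torch.utils.data.DataLoader(
		train_dataset, batch_size=64, num_workers=4, shuffle=True,
		pin_memory=True, drop_last=True, worker_init_fn=worker_init)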

What is going on?

I cannot reproduce the issue and get:

EPOCH 0
EPOCH 1
EPOCH 2
EPOCH 3
EPOCH 4
EPOCH 5
EPOCH 6
EPOCH 7
EPOCH 8
EPOCH 9

as the output.

Good to know. While putting together a minimal example that clones and runs my full code from GitHub to reproduce the issue more faithfully, I noticed that the error seems to be specific to my conda environment.

I cannot easily recreate that conda environment from scratch to test, because packages have been updated since then and it was a mix of conda/mim/pip installs. So far, only the original conda environment consistently shows the error, and the newly created ones have not shown it at all. Hopefully it stays that way :wink:
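
For anyone trying to pin down a similar environment-specific difference, a quick sanity check (just a sketch; it only prints the version info that seems most relevant here) is to run the following in both the failing and a working environment and diff the output:

	import torch, torchvision, wandb
	print('torch       ', torch.__version__)
	print('torchvision ', torchvision.__version__)
	print('wandb       ', wandb.__version__)
	print('CUDA (torch)', torch.version.cuda)
	print('cuDNN       ', torch.backends.cudnn.version())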