I am using PyTorch on a Kubernetes pod running Ubuntu 20.04. Relevant version numbers:
- pytorch=1.11.0=py3.8_cuda11.3_cudnn8.2.0_0
- pytorch-mutex=1.0=cuda
- torchaudio=0.11.0=py38_cu113
- torchmetrics=0.8.2=pyhd8ed1ab_0
- torchvision=0.12.0=py38_cu113
A dataloader with four workers randomly freezes. After interrupting the process, I see this traceback:
File "/MYPACKAGES/encoder_trainers.py", line 271, in compute_predictions
for batch in dl:
File "/opt/conda/envs/torch/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 530, in __next__
data = self._next_data()
File "/opt/conda/envs/torch/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1207, in _next_data
idx, data = self._get_data()
File "/opt/conda/envs/torch/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1163, in _get_data
success, data = self._try_get_data()
File "/opt/conda/envs/torch/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1011, in _try_get_data
data = self._data_queue.get(timeout=timeout)
File "/opt/conda/envs/torch/lib/python3.8/queue.py", line 179, in get
self.not_empty.wait(remaining)
File "/opt/conda/envs/torch/lib/python3.8/threading.py", line 306, in wait
gotit = waiter.acquire(True, timeout)
KeyboardInterrupt
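For context, compute_predictions (the frame at the top of the traceback) essentially just iterates the dataloader. This is a minimal sketch of what it does; every name except dl is illustrative, not the actual code:

def compute_predictions(self, dl):
    # Illustrative sketch only: collect model outputs over the dataloader.
    self.model.eval()
    outputs = []
    with torch.no_grad():
        for batch in dl:  # this is the line that hangs (line 271 in the traceback)
            images, labels = batch
            outputs.append(self.model(images))
    return torch.cat(outputs)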
This is my custom dataset, with print statements added for debugging. All of the prints are executed, meaning that control leaves the dataset and returns to the dataloader. The item on which the dataloader freezes is completely random.
from os.path import join

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset


class ImageDataset(Dataset):
    def __init__(self, dataset: pd.DataFrame, img_fld: str, img_transforms=None,
                 n_classes=None, img_size=224, l1normalization=False):
        super().__init__()
        self.n_classes = n_classes or dataset.shape[1]
        assert self.n_classes == dataset.shape[1]
        print("IMAGE DATASET, NCLASSES SET TO", self.n_classes)
        self.ds = dataset
        self.img_fld = img_fld
        self.transforms = img_transforms
        self.img_size = img_size
        self.l1normalization = l1normalization
        if l1normalization:
            print("IMAGE DATASET, WARNING: l1 normalization of target values set to TRUE")
        else:
            print("IMAGE DATASET: l1 normalization of target values set to FALSE")

    def __len__(self):
        return len(self.ds)
    #< len

    def __getitem__(self, idx):
        print("Accessing item:", idx)
        assert (idx >= 0) and (idx < len(self.ds))
        item = self.ds.iloc[idx]  # one row: the index is the filename, the values are the labels
        filename = item.name
        labels = item.values
        print("\t filename:", filename)
        print("\t labels:", sum(labels))
        if self.l1normalization:
            labels = labels / labels.sum()
        image = self.load_image(filename)
        print("image loaded")
        labels = torch.tensor(labels.astype(np.float32))
        print('about to return image and labels')
        return image, labels
    #< getitem

    def load_image(self, img_filename):
        fn = join(self.img_fld, img_filename)
        img = Image.open(fn)
        print("ok, read image:", img_filename)
        if self.transforms is not None:
            img = self.transforms(img)
        print("returning image")
        return img
    #< load_image
#< ImageDataset
The dataset is managed by a dataloader created as follows:
DataLoader(dataset, batch_size=128, shuffle=False, num_workers=4, drop_last=[False, False, False], pin_memory=False)
As a further note, the dataloaders for train, validation and test are kept in a dictionary:
dataloaders = {
    'train': DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=4, drop_last=[False, False, False], pin_memory=False),
    'valid': DataLoader(valid_dataset, batch_size=128, shuffle=False, num_workers=4, drop_last=[False, False, False], pin_memory=False),
    'test': DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4, drop_last=[False, False, False], pin_memory=False),
}
In one run, these were the last messages before it froze:
Accessing item: 1151
filename: CXR730_IM-2290-1002.png
labels: 1
ok, read image: CXR730_IM-2290-1002.png
returning image
image loaded
about to return image and labels
In another run:
Accessing item: 767
filename: CXR61_IM-2197-2001.png
labels: 1
ok, read image: CXR61_IM-2197-2001.png
returning image, shape: torch.Size([3, 224, 224])
image loaded
about to return image and labels
The freezing started when I reorganized my code, moving from simple Python scripts to a set of base classes and subclasses. The dataloaders are always managed in a base class that is never instantiated directly; only its subclasses are.
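To give an idea of the structure, this is a minimal sketch of how the classes are organized; the class and method names here are illustrative, not the actual code:

from torch.utils.data import DataLoader

class BaseTrainer:
    # Never instantiated directly; builds and keeps the dataloaders.
    def __init__(self, datasets):
        self.dataloaders = {
            split: DataLoader(ds, batch_size=128, shuffle=False,
                              num_workers=4, pin_memory=False)
            for split, ds in datasets.items()
        }

    def compute_predictions(self, split):
        for batch in self.dataloaders[split]:  # freezes here when num_workers > 0
            ...

class EncoderTrainer(BaseTrainer):
    # Only subclasses like this one are instantiated.
    pass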
Other info:
- the freezing occurs more often with smaller batch sizes, e.g. 32, than with larger ones, e.g. 128.
- it occurs very often, but not always, on the same image (1151 above). The problem is not related to the specific image file, because it is read, processed, and resized successfully.
- it never freezes with dataloaders created with num_workers=0.
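To be explicit about that last point, the only difference between the hanging and the working configuration is num_workers; all other arguments stay the same:

# Same arguments as above, only num_workers changes:
DataLoader(dataset, batch_size=128, shuffle=False, num_workers=4, pin_memory=False)  # randomly freezes
DataLoader(dataset, batch_size=128, shuffle=False, num_workers=0, pin_memory=False)  # never freezes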