The documentation for IterableDataset is very clear that: "When num_workers > 0, each worker process will have a different copy of the dataset object, so it is often desired to configure each copy independently to avoid having duplicate data returned from the workers."
The same point comes up in the thread "What data does each worker process hold? Does it hold the full dataset object or only a batch of it?", where the answer says that each worker holds a reference to its own dataset.
It is framed even more explicitly in "Is a Dataset copied as part of Dataloader with multiple workers?"
I will frame my use-case and later describe what I’m curious about here.
I am using Python's zipfile.ZipFile for one of my datasets. From experimentation, reading the same zip archive on disk concurrently through multiple ZipFile objects seems to be fine. However, reading concurrently through a single, shared ZipFile object fails quickly.
By default, if I use a Dataset that holds a shared ZipFile object, the __getitem__ calls quickly throw ZipFile-related errors (usually failed CRC checks). Here is an example dataset:
```python
import os
import zipfile

import torch
from torch.utils.data import Dataset
from torchvision.io import decode_image


class ImageArchive(Dataset):
    def __init__(self, archive_path: str) -> None:
        super().__init__()
        # A single ZipFile handle, opened once in the main process.
        self.archive = zipfile.ZipFile(os.path.abspath(archive_path), "r")
        # Keep only file entries; directories cannot be decoded as images.
        self.image_paths = [file for file in self.archive.infolist() if not file.is_dir()]

    def __len__(self):
        # Count only the entries __getitem__ can index (filelist also includes directories).
        return len(self.image_paths)

    def __getitem__(self, index):
        # All flavors of read fail, even when copying to a new object :/
        img_bytes = bytearray(self.archive.read(self.image_paths[index].filename))
        img_tensor = torch.frombuffer(img_bytes, dtype=torch.uint8)
        return decode_image(img_tensor)
```
When it is run through a DataLoader with num_workers > 0:
```python
from torch.utils.data import DataLoader

dataset = ImageArchive('/scr/data/MorphEm/idr0003.zip')
train_loader = DataLoader(dataset=dataset, batch_size=128, num_workers=2)
for index, _ in enumerate(train_loader):
    if index > 5:
        break
```
it throws a BadZipFile error whose traceback usually starts with "zipfile.BadZipFile: Bad CRC-32 for file …". Notably, the zip archive on disk is not corrupt: running the same code with num_workers=0 completes as expected.
This matches what I described above: multiple worker processes reading from the zip archive through the same ZipFile object, in this case the shared self.archive.
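The CRC failures are consistent with one OS-level detail: a forked child inherits the parent's file descriptors, and the parent's and child's descriptors refer to the same open file description, so they share a single file offset. My assumption (not verified against the zipfile source here) is that each worker's copy of self.archive seeks and reads through that shared offset, so reads from two workers interleave and return the wrong bytes, which then fail the CRC check. The POSIX-only sketch below (function name hypothetical) demonstrates only the shared-offset part: a child process seeks, and the parent sees the new offset on its own descriptor.

```python
import os
import tempfile


def child_seek_moves_parent_offset(target: int = 7) -> int:
    """Fork, let the child lseek() an inherited descriptor, and return the
    offset the parent then observes on its own copy of that descriptor."""
    with tempfile.TemporaryFile() as f:
        f.write(b"x" * 64)
        f.flush()
        fd = f.fileno()
        os.lseek(fd, 0, os.SEEK_SET)
        pid = os.fork()  # POSIX only
        if pid == 0:
            # Child: move the offset it inherited, then exit immediately
            # (os._exit skips cleanup, so the child never touches f again).
            os.lseek(fd, target, os.SEEK_SET)
            os._exit(0)
        os.waitpid(pid, 0)
        # Parent: ask the kernel for the current offset of its own descriptor.
        return os.lseek(fd, 0, os.SEEK_CUR)


print(child_seek_moves_parent_offset())  # prints 7: the offset is shared
```

If this is what happens with the dataset, separate ZipFile objects per worker help because each one opens its own descriptor with its own offset.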
I bring up IterableDataset here because, from the documentation, it sounds like each worker with num_workers > 0 should have its own copy of the dataset. However, when I run this as an IterableDataset and print each worker's WorkerInfo, the dataset in every worker has the same memory address:
```
WorkerInfo(id=0, num_workers=2, seed=1164168763176036335, dataset=<__main__.ImageArchive object at 0x7f7171cd0fa0>)
WorkerInfo(id=1, num_workers=2, seed=1164168763176036336, dataset=<__main__.ImageArchive object at 0x7f7171cd0fa0>)
```
As a result, a simple IterableDataset version fails with the same error as the map-style Dataset above.
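One way to probe whether equal addresses really imply a shared object: assuming the workers are started with fork (historically the default on Linux), each worker receives a copy of the parent's address space at the same virtual addresses, so CPython's id() (which is what the 0x... in the default repr shows) can match even though the objects are distinct. The sketch below checks this with a hypothetical Holder class: a forked child mutates its copy and reports its id(), and the parent then inspects its own copy.

```python
import os


class Holder:
    """Hypothetical stand-in for a dataset object."""

    def __init__(self) -> None:
        self.value = "parent"


def forked_copy_is_independent() -> tuple:
    """Return (same_id, parent_value) after a forked child mutates its copy."""
    obj = Holder()
    read_end, write_end = os.pipe()
    pid = os.fork()  # POSIX only
    if pid == 0:
        # Child: change its copy of obj and report the id() it sees.
        os.close(read_end)
        obj.value = "child"
        os.write(write_end, str(id(obj)).encode())
        os._exit(0)
    os.close(write_end)
    child_id = int(os.read(read_end, 64).decode())
    os.close(read_end)
    os.waitpid(pid, 0)
    return child_id == id(obj), obj.value


print(forked_copy_is_independent())  # (True, 'parent')
```

Same id(), yet the child's write is invisible to the parent, which suggests the WorkerInfo addresses alone cannot tell us whether the dataset objects are shared.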
I cannot deep copy these objects because of the open file object that ZipFile holds (for the same reason, they cannot be pickled), but even a shallow copy shows a different memory address:
```python
from copy import copy

dataset = ImageArchive('/scr/data/MorphEm/idr0003.zip')
print(dataset)
print(copy(dataset))
```
This prints:
```
<__main__.ImageArchive object at 0x7f720f560820>
<__main__.ImageArchive object at 0x7f71270ef400>
```
The memory addresses are clearly different here. However, I don't know whether a separate shallow copy per worker would allow concurrent reads; I haven't tried that yet.
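One thing worth checking before trying that: copy.copy() duplicates only the outer object, so the copy's archive attribute should still be the very same ZipFile. The sketch below tests this with a hypothetical ArchiveHolder class (a minimal stand-in for ImageArchive that only opens the archive) and a throwaway archive built in a temporary directory; it also confirms that copy.deepcopy() raises TypeError on it.

```python
import copy
import os
import tempfile
import zipfile


class ArchiveHolder:
    """Hypothetical stand-in for ImageArchive: it only holds a ZipFile."""

    def __init__(self, archive_path: str) -> None:
        self.archive = zipfile.ZipFile(archive_path, "r")


def check_copies(archive_path: str) -> dict:
    original = ArchiveHolder(archive_path)
    shallow = copy.copy(original)
    try:
        copy.deepcopy(original)
        deepcopy_ok = True
    except TypeError:
        # The ZipFile holds an open file object (and a lock); neither can be copied.
        deepcopy_ok = False
    return {
        "different_wrapper": shallow is not original,
        "same_zipfile": shallow.archive is original.archive,
        "deepcopy_ok": deepcopy_ok,
    }


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo.zip")
    with zipfile.ZipFile(path, "w") as zf:
        zf.writestr("a.txt", b"hello")
    print(check_copies(path))
    # {'different_wrapper': True, 'same_zipfile': True, 'deepcopy_ok': False}
```

So a shallow copy per worker would most likely still read through one shared ZipFile.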
I can work around the problem (ZipFile combined with IterableDataset not producing distinct copies of my dataset) by passing a worker_init_fn to the DataLoader and having it instantiate a new dataset in each worker when num_workers > 0. Each worker then has a different ZipFile object, and concurrent reads from multiple workers work as expected.
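A variant of that workaround, sketched here under assumptions: instead of building a whole new dataset, worker_init_fn replaces only the archive handle on the worker's copy of the dataset, which torch.utils.data.get_worker_info() exposes as .dataset. reopen_archive is a hypothetical helper name, and it assumes the dataset keeps its ZipFile in an attribute named archive that was opened from a path, as ImageArchive does.

```python
import zipfile


def reopen_archive(dataset) -> None:
    """Give `dataset` (anything with an `.archive` ZipFile opened from a path)
    its own freshly opened ZipFile, and therefore its own file descriptor."""
    dataset.archive = zipfile.ZipFile(dataset.archive.filename, "r")


def worker_init_fn(worker_id: int) -> None:
    # Imported lazily so reopen_archive() itself does not depend on PyTorch.
    from torch.utils.data import get_worker_info

    info = get_worker_info()
    # info.dataset is this worker process's copy of the dataset object.
    reopen_archive(info.dataset)


# Usage (not executed here):
# loader = DataLoader(dataset, batch_size=128, num_workers=2,
#                     worker_init_fn=worker_init_fn)
```

Re-opening by filename keeps image_paths valid, since the new handle reads the same archive on disk.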
However, I would like to know why the IterableDataset documentation says each of my workers gets its own copy of the dataset, when printing out the workers suggests they share the same object: they report the same memory address and appear to read from the same ZipFile object.