I have met this weird error when num_workers
is set to 1 or larger. Here is my MWE to reproduce:
from PIL import Image
import torch
from torchvision import transforms
class FrameDataset(torch.utils.data.Dataset):
def __init__(self, pil_imgs):
super(FrameDataset, self,).__init__()
self.pil_imgs = pil_imgs
self.transform = make_transform()
def __len__(self):
return len(self.pil_imgs)
def __getitem__(self, idx):
img = self.pil_imgs[idx]
return self.transform(img)
def make_transform():
trans = transforms.Compose([
transforms.CenterCrop(224),
transforms.ToTensor(),
])
return trans
def main():
im = Image.open("../data/debug_img/S04000000003v5Rw0Nr9y7R___0000.jpg")
pil_imgs = [im]
num_round = 2
for _ in range(num_round):
val_dataset = FrameDataset(pil_imgs)
val_loader = torch.utils.data.DataLoader(val_dataset,
batch_size=1,
num_workers=1)
for im_batch in val_loader:
continue
if __name__ == "__main__":
main()
Okay, the actual image does not matter, as long as I set num_workers=1
, for any image, the above script will error out something like this:
OSError: image file is truncated ( 11 bytes not processed)
Yes, I have read this post and this post. But I do not think setting ImageFile.LOAD_TRUNCATED_IMAGES = True
is the right way to solve the issue, because the image is not broken or corrupt at all.
I also noticed that if I set num_round
to 1, the error will disappear.
So in summary, I have the following table
num_round | num_workers | error or not? |
---|---|---|
1 | 1 | no error |
1 | 0 | no error |
2 | 1 | error |
2 | 0 | no error |