This has been bugging me for the last few days. I have narrowed the issue down (I hope) to my custom dataset, which subclasses `torch.utils.data.Dataset`.
The model trains fine with `num_workers=0` in the dataloader. Unlike all the other topics on this, I *can* train with `num_workers>0`, but I get weird results: instead of converging to ~70% accuracy and a corresponding loss, the model converges to ~20% accuracy, and the validation numbers look off as well.
- My code has an `if __name__ == "__main__"` safeguard and is simply run with `python main.py`.
- I only change the `num_workers` parameter in the training dataloader.
- I can train fine with other datasets, e.g. `torchvision.datasets.MNIST`.
- My custom dataset has the following `__getitem__` function (a fuller, self-contained sketch follows below):
```python
def __getitem__(self, idx):
    # self.data holds one JSON string per sample
    sample = json.loads(self.data[idx])
    x = torch.tensor(sample["x"])
    label = sample["y"]
    return x, label
```
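For context, here is a simplified, self-contained version of how I build the dataset and dataloader; the class name, file name, and hyperparameters below are placeholders rather than my exact code:

```python
import json
import torch
from torch.utils.data import Dataset, DataLoader

class JsonLinesDataset(Dataset):
    """Placeholder name; reads one JSON object per line into memory."""

    def __init__(self, path):
        with open(path) as f:
            self.data = f.readlines()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = json.loads(self.data[idx])
        x = torch.tensor(sample["x"])
        label = sample["y"]
        return x, label

# Only num_workers changes between the working and the broken runs.
train_loader = DataLoader(JsonLinesDataset("train.jsonl"),
                          batch_size=32, shuffle=True, num_workers=4)
```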
Nothing here looks like something that could go wrong under `num_workers>0` multiprocessing. Any tips? For now I can work with `num_workers=0`, but I am taking a performance hit I would rather avoid.
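If it helps narrow things down, this is the kind of sanity check I am planning to run (a sketch, reusing the placeholder `JsonLinesDataset` from above): with `shuffle=False`, the first batch should be identical no matter how many workers fetch it.

```python
import torch
from torch.utils.data import DataLoader

def first_batch(dataset, num_workers):
    # shuffle=False so every run should yield the same first batch
    loader = DataLoader(dataset, batch_size=32, shuffle=False,
                        num_workers=num_workers)
    return next(iter(loader))

dataset = JsonLinesDataset("train.jsonl")  # placeholder from the sketch above
x0, y0 = first_batch(dataset, 0)  # single-process baseline
x4, y4 = first_batch(dataset, 4)  # multi-worker run
print(torch.equal(x0, x4), torch.equal(y0, y4))  # both should be True
```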