The DataLoader class is hanging (or crashing) in Windows but not in Linux with the following example:
#Demo of DataLoader crashing in Windows and with Visual Studio Code
import torch
from torch.utils.data import Dataset, DataLoader
class SimpleData(Dataset):
"""Very simple dataset"""
def __init__(self):
self.data = range(20)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
#Create dataset
myDataSet = SimpleData()
#put dataset into DataLoader:
dataloader = DataLoader(myDataSet, batch_size=4, num_workers=1)
print('Using DataLoader to show data: ')
for i, sample_batched in enumerate(dataloader):
print('batch ', i, ':', sample_batched)
print("--- Done ---")
When run in python in Liux I get, this outputs:
Using DataLoader to show data:
batch 0 : tensor([0, 1, 2, 3])
batch 1 : tensor([4, 5, 6, 7])
batch 2 : tensor([ 8, 9, 10, 11])
batch 3 : tensor([12, 13, 14, 15])
batch 4 : tensor([16, 17, 18, 19])
--- Done ---
Now when I insert an if name==‘main’: line, it works fine, like this:
#Demo of DataLoader working
import torch
from torch.utils.data import Dataset, DataLoader
class SimpleData(Dataset):
"""Very simple dataset"""
def __init__(self):
self.data = range(20)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
#This crashes unless I include this line:
if __name__ == '__main__':
#Create dataset
myDataSet = SimpleData()
#put dataset into DataLoader:
dataloader = DataLoader(myDataSet, batch_size=4, num_workers=1)
print('Using DataLoader to show data: ')
for i, sample_batched in enumerate(dataloader):
print('batch ', i, ':', sample_batched)
print("--- Done ---")
Can anyone explain what is wrong with the top example? Why is it necessary to check for main? I am getting some other strange behavior with DataLoader, and so want to at least understand this.