Hi!
First of all, I have been reading posts, GitHub issues and threads for a few hours now. I learned that multiprocessing (i.e. multiple DataLoader workers) on Windows and/or Jupyter (Google Colab) seems to be a pain or not to work at all.
After a lot of trial and error and following a lot of advice, it seems to work for me now and gives me an immense speed improvement, but sadly only with a downloaded dataset. If I try it with my own dataset, it freezes immediately and I have to restart the runtime, while the CPU sits at 0% the whole time.
I have also read a lot of similar posts, but none of them seems to have the same issue. I believe there is an error in my Dataset class that I am unable to find. I hope someone can help me out.
This test works perfectly fine:
```python
import time
from tqdm import tqdm
import torch
import torchvision

def train(data_loader):
    start = time.time()
    for _ in tqdm(range(10)):
        for x in data_loader:
            pass
    end = time.time()
    return end - start

if __name__ == '__main__':
    train_dataset = torchvision.datasets.FashionMNIST(
        root=".", train=True, download=True,
        transform=torchvision.transforms.ToTensor()
    )
    batch_size = 32
    train_loader1 = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=0)
    train_loader2 = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=8)
    train_loader3 = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=8, persistent_workers=True)
    train_loader4 = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=10, persistent_workers=True)
    print(train(train_loader1))
    print(train(train_loader2))
    print(train(train_loader3))
    print(train(train_loader4))
```
This one, with my own dataset, gets stuck immediately after train_loader1, i.e. as soon as I switch on multiple workers. It freezes and the CPU sits idle at 0%.
It looks like the iterator just never returns a value. Is this similar to "Multiple Dataloader Workers in multi Threading"?
```python
import time
import torch
from torch.utils.data import DataLoader  # Gives easier dataset management by creating mini-batches etc.
from tqdm import tqdm  # For a nice progress bar!
import numpy as np
# smallShaderDataset is the Dataset class shown below

def train(data_loader):
    start = time.time()
    for _ in tqdm(range(10)):
        for x in data_loader:
            pass
    end = time.time()
    return end - start

if __name__ == '__main__':
    batch_size = 64
    root_dir = R'[REDACTED]\MiniStoneShaderDataset'
    dataset = smallShaderDataset(csv_file='Labels.csv', root_dir=root_dir)
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [99, 9900])
    train_loader1 = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
    train_loader2 = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, persistent_workers=True)
    train_loader3 = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True, persistent_workers=True)
    print(train(train_loader1))
    print(train(train_loader2))
    print(train(train_loader3))
```
Using this Dataset class:
```python
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
import numpy as np
#from skimage import io

class smallShaderDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.annotations = pd.read_csv(os.path.join(root_dir, csv_file))
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        # First CSV column: file stem of the .npy input; remaining columns: parameters
        data_path = os.path.join(self.root_dir, 'LearnDataCombined', self.annotations.iloc[index, 0] + ".npy")
        data = np.load(data_path)
        parameters = self.annotations.iloc[index, 1:]
        parameters = np.array([parameters], dtype=float).flatten()
        data = np.array(data, dtype=float)
        sample = {'input': data, 'parameters': parameters}
        return data, parameters
        #return sample
```
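For comparison, a sketch of a leaner variant of the class above. It keeps the layout the original code implies (first CSV column is the file stem of a `.npy` file in `LearnDataCombined`, the remaining columns are numeric parameters), parses the CSV once in `__init__` instead of indexing the DataFrame on every item, and returns float32 tensors instead of float64 arrays. The class name `ShaderDatasetFloat32` is hypothetical, it applies `transform` (which the original stores but never uses), and it is not claimed to fix the hang by itself:

```python
import os

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset


class ShaderDatasetFloat32(Dataset):
    """Hypothetical variant of smallShaderDataset that returns float32 tensors.

    Assumes <root_dir>/<csv_file> has the file stem in column 0 and numeric
    parameters in the other columns, and that each input array is stored as
    <root_dir>/LearnDataCombined/<stem>.npy.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        annotations = pd.read_csv(os.path.join(root_dir, csv_file))
        self.root_dir = root_dir
        self.transform = transform
        # Parse the CSV once: a list of strings plus one float32 array,
        # so __getitem__ does no per-item pandas indexing.
        self.stems = annotations.iloc[:, 0].astype(str).tolist()
        self.parameters = annotations.iloc[:, 1:].to_numpy(dtype=np.float32)

    def __len__(self):
        return len(self.stems)

    def __getitem__(self, index):
        data_path = os.path.join(self.root_dir, "LearnDataCombined", self.stems[index] + ".npy")
        # Load the input and convert it to a float32 tensor
        data = torch.from_numpy(np.load(data_path).astype(np.float32))
        if self.transform is not None:
            data = self.transform(data)
        return data, torch.from_numpy(self.parameters[index])
```

It can be used in place of `smallShaderDataset` with the same arguments, e.g. `ShaderDatasetFloat32(csv_file='Labels.csv', root_dir=root_dir)`.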
Everything works (slowly) with num_workers=0.
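With num_workers=0 the loop finishes, while with workers it blocks without any message. A minimal sketch for making that visible, assuming the standard `timeout` argument of `DataLoader`: when worker processes are used, it raises a `RuntimeError` if no batch arrives within the given number of seconds, so the loop stops with an error instead of hanging. `make_debug_loader` is a hypothetical helper and the 60-second default is an arbitrary choice:

```python
from torch.utils.data import DataLoader


def make_debug_loader(dataset, batch_size=64, num_workers=4, timeout_s=60.0, **loader_kwargs):
    """Hypothetical helper: a DataLoader that errors out instead of hanging.

    With num_workers > 0, DataLoader(timeout=...) raises a RuntimeError when
    no batch arrives from the workers within timeout_s seconds. The
    single-process loader (num_workers=0) only accepts timeout=0.
    """
    if num_workers == 0:
        timeout_s = 0.0
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        timeout=timeout_s,
        **loader_kwargs,  # e.g. pin_memory=True, persistent_workers=True
    )
```

It could replace the loaders in the script above, e.g. `train_loader2 = make_debug_loader(train_dataset, num_workers=4, pin_memory=True, persistent_workers=True)`. This only surfaces the failure; it does not fix whatever keeps the workers from delivering batches.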
There is probably something wrong in my Dataset class which I am unable to find in full tunnel-vision mode. I really hope someone can help me find the issue.
If it really is a bug in my Dataset class, why does the downloaded dataset work fine?
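One difference between the two cases that can be checked directly is where each dataset class is defined. With the spawn start method (the default on Windows), every worker process rebuilds the dataset by unpickling it, which means importing its class by module name. FashionMNIST comes from the installed torchvision package, whereas a class defined in a notebook cell lives in `__main__`, and a notebook kernel's `__main__` usually cannot be re-imported by the workers. A heuristic sketch, assuming this is what happens here; `workers_can_import` is a hypothetical helper that loosely mirrors how `multiprocessing` decides whether a spawned child can re-import `__main__`:

```python
import sys


def workers_can_import(cls):
    """Heuristic: can a worker started with "spawn" re-import class `cls`?

    Classes from an installed package (e.g. torchvision's FashionMNIST) are
    importable by module name. Classes defined in __main__ are only
    importable if __main__ itself can be re-imported, i.e. it was run from a
    file (`python train.py`) or as a module (`python -m ...`).
    """
    if cls.__module__ != "__main__":
        return True
    main = sys.modules["__main__"]
    # Notebook kernels and interactive sessions usually have neither a
    # __file__ nor a named __spec__ on __main__.
    has_file = getattr(main, "__file__", None) is not None
    has_spec_name = getattr(getattr(main, "__spec__", None), "name", None) is not None
    return has_file or has_spec_name
```

For the objects in the question, `workers_can_import(type(dataset))` checks the custom class. Note that `train_dataset` returned by `random_split` is a `Subset`, whose own class is always importable, so the check should use `type(train_dataset.dataset)` rather than `type(train_dataset)`.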
Thank you for any help!
Greetings