I have a Dataset class as follows. I want to generate batches of size 64 using DataLoader of torch.
from torch.utils import data
import random
import numpy as np
import torch
class Dataset(data.Dataset):
def __init__(self, qrel, query_vecs, offsets, fname):
self.qrel = qrel
self.qrel_list = list(qrel)
self.query_vecs = query_vecs
self.offsets = offsets
self.f = open(fname, 'rt')
self.keys = offsets.keys()
self.skipped = 0
def __len__(self):
return len(self.qrel)
def __getitem__(self, index):
key = self.qrel_list[index]
query = self.query_vecs[key].reshape(1, 300)
posid = self.qrel[key][0]
offset = self.offsets[posid]
self.f.seek(offset)
pos = np.fromstring(self.f.readline().split('\t')[1], sep=" ").reshape(1, 300)
neg_keys = random.sample(self.keys, 4)
negs = []
for neg in neg_keys:
self.f.seek(self.offsets[neg])
negs.append(np.fromstring(self.f.readline().split('\t')[1], sep=" "))
query = torch.tensor(query).float()
pos = torch.tensor(pos).float()
negs = torch.tensor(np.array(negs).reshape(4, 300)).float()
return (query, pos, negs)
I initialise this Dataset as follows.
qrel = {}
with gzip.open("msmarco-doctrain-qrels.tsv.gz", 'rt', encoding='utf8') as f:
tsvreader = csv.reader(f, delimiter="\t")
for [topicid, _, docid, rel] in tsvreader:
assert rel == "1"
if topicid in qrel:
qrel[topicid].append(docid)
else:
qrel[topicid] = [docid]
with open('query_vecs.pickle', 'rb') as f:
query_vecs = pickle.load(f)
with open('doc_vec_offsets.p', 'rb') as f:
offsets = pickle.load(f)
fname = 'doc_vecs.tsv'
params = {'batch_size': 16,
'shuffle': False,
'num_workers': 8,
'drop_last': True}
ds = Dataset(qrel, query_vecs, offsets, fname)
dl = data.DataLoader(ds, **params)
If I run following script, it works fine for 1 worker, but if I set num_workers=8
, in DataLoader, I get IndexError. What is the reason for this error?
for batch in dl:
...
train_model(batch)
When running it with 8 workers, i get this error.
Traceback (most recent call last):
File "/home/ruchit/Desktop/NLP/project/playground.py", line 49, in <module>
for s in dl:
File "/home/ruchit/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 582, in __next__
return self._process_next_batch(batch)
File "/home/ruchit/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 608, in _process_next_batch
raise batch.exc_type(batch.exc_msg)
IndexError: Traceback (most recent call last):
File "/home/ruchit/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 99, in _worker_loop
samples = collate_fn([dataset[i] for i in batch_indices])
File "/home/ruchit/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 99, in <listcomp>
samples = collate_fn([dataset[i] for i in batch_indices])
File "/home/ruchit/Desktop/NLP/project/my_dataset.py", line 26, in __getitem__
pos = np.fromstring(f.readline().split('\t')[1], sep=" ").reshape(1, 300)
IndexError: list index out of range
Can someone help me here?