I’ve defined following class for my Dataset, so that I can later feed it to DataLoader class to generate my training data.
from torch.utils import data
import random
import numpy as np
class Dataset(data.Dataset):
def __init__(self, qrel, query_vecs, offsets, fname):
self.qrel = qrel
self.query_vecs = query_vecs
self.offsets = offsets
self.f = open(fname, 'rt')
self.keys = offsets.keys()
def __len__(self):
return len(self.qrel)
def __iter__(self):
return iter(self.qrel)
def __next__(self):
return next(iter(self.qrel))
def __getitem__(self, key):
# Following line fetches a query vector from the dictionary
query = self.query_vecs[key]
# Following four lines are responsible for fetching doc vector for positive document.
# All doc vectors are stored as a tsv file
posid = self.qrel[key][0]
offset = self.offsets[posid]
self.f.seek(offset)
pos = np.fromstring(self.f.readline().split('\t')[1], sep=" ")
#Following lines are responsible for fetching randomly selected four
# Document vectors
negs = []
neg_keys = random.sample(self.keys, 4)
for key in neg_keys:
self.f.seek(self.offsets[key])
negs.append(np.fromstring(self.f.readline().split('\t')[1], sep=" "))
# Return tuple of query, positive, negative docs.
return (query, pos, np.array(negs))
I’m using this class as follows :
import torch
from torch.utils import data
import pickle
import random
import numpy as np
import gzip
import csv
from my_dataset import Dataset
qrel = {}
with gzip.open("msmarco-doctrain-qrels.tsv.gz", 'rt', encoding='utf8') as f:
tsvreader = csv.reader(f, delimiter="\t")
for [topicid, _, docid, rel] in tsvreader:
assert rel == "1"
if topicid in qrel:
qrel[topicid].append(docid)
else:
qrel[topicid] = [docid]
# query_vecs is a dictionary
with open('query_vecs.pickle', 'rb') as f:
query_vecs = pickle.load(f)
# Offsets is also a dictionary
with open('doc_vec_offsets.p', 'rb') as f:
offsets = pickle.load(f)
fname = 'doc_vecs.tsv'
params = {'batch_size': 1,
'shuffle': False,
'num_workers': 1}
max_epochs = 1
ds = Dataset(qrel,query_vecs, offsets, fname)
dataloader = data.DataLoader(ds, **params)
for epoch in range(max_epochs):
for sample in dataloader:
print(sample)
print(type(sample))
break
However, I’m getting following error. I don’t know why I’m getting this KeyError, because I’ve already replaced the __iter__()
and __next__()
method. What else should I do? Basically my dataset starts from that qrel dictionary, where it first finds the query vector, positive document vector and 4 negative doc vectors. However, I can not just iterate through that qrel in the class Dataset. I get following error and I’m stuck.
Exception: KeyError:Traceback (most recent call last):
File "/home/ruchit/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 99, in _worker_loop
samples = collate_fn([dataset[i] for i in batch_indices])
File "/home/ruchit/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 99, in <listcomp>
samples = collate_fn([dataset[i] for i in batch_indices])
File "/home/ruchit/Desktop/NLP/project/my_dataset.py", line 24, in __getitem__
query = self.query_vecs[self.qrel[key]]
KeyError: 0
It looks like even after I’ve defined iter() method, it still is replacing it with iterator corresponding to some list.
The problem seems to be with DataLoader class. I can iterate through object of class Dataset
, but after doing
ds = Dataset(qrel,query_vecs, offsets, fname)
dataloader = data.DataLoader(ds, **params)
dataloader is no longer iterable.
For example,
below code works fine
for sample in ds:
print(sample)
But, after passing it to DataLoader, it not longer works. Following thing fails :
for sample in dataloader:
print(sample)