Issue with custom Dataset class for DataLoader

I’ve defined the following class for my Dataset, so that I can later feed it to the DataLoader class to generate my training data.

from torch.utils import data
import random
import numpy as np

class Dataset(data.Dataset):
    def __init__(self, qrel, query_vecs, offsets, fname):
        self.qrel = qrel
        self.query_vecs = query_vecs
        self.offsets = offsets
        self.f = open(fname, 'rt')
        self.keys = offsets.keys()
        
    def __len__(self):
        return len(self.qrel)
    
    def __iter__(self):
        return iter(self.qrel)
    
    def __next__(self):
        return next(iter(self.qrel))
    
    def __getitem__(self, key):
        # Following line fetches a query vector from the dictionary
        query = self.query_vecs[key]
        
        # Following four lines are responsible for fetching the doc vector
        # for the positive document. All doc vectors are stored in a tsv file.
        posid = self.qrel[key][0]
        offset = self.offsets[posid]
        self.f.seek(offset)
        pos = np.fromstring(self.f.readline().split('\t')[1], sep=" ")
        
        # Following lines are responsible for fetching four randomly
        # selected document vectors as negatives.
        negs = []
        neg_keys = random.sample(self.keys, 4)
        for key in neg_keys:
            self.f.seek(self.offsets[key])
            negs.append(np.fromstring(self.f.readline().split('\t')[1], sep=" "))
            
        # Return tuple of query, positive, negative docs.  
        return (query, pos, np.array(negs))

I’m using this class as follows:

import torch
from torch.utils import data
import pickle
import random
import numpy as np
import gzip
import csv
from my_dataset import Dataset

qrel = {}
with gzip.open("msmarco-doctrain-qrels.tsv.gz", 'rt', encoding='utf8') as f:
    tsvreader = csv.reader(f, delimiter="\t")
    for [topicid, _, docid, rel] in tsvreader:
        assert rel == "1"
        if topicid in qrel:
            qrel[topicid].append(docid)
        else:
            qrel[topicid] = [docid]
      
# query_vecs is a dictionary      
with open('query_vecs.pickle', 'rb') as f:
    query_vecs = pickle.load(f)

# Offsets is also a dictionary    
with open('doc_vec_offsets.p', 'rb') as f:
    offsets = pickle.load(f)
    
fname = 'doc_vecs.tsv'


params = {'batch_size': 1,
          'shuffle': False,
          'num_workers': 1}

max_epochs = 1
ds = Dataset(qrel,query_vecs, offsets, fname)
dataloader = data.DataLoader(ds, **params)

for epoch in range(max_epochs):
    for sample in dataloader:
        print(sample)
        print(type(sample))
        break

However, I’m getting the following error. I don’t know why I’m getting this KeyError, because I’ve already overridden the __iter__() and __next__() methods. What else should I do? Basically my dataset starts from the qrel dictionary: for each entry it first finds the query vector, the positive document vector, and 4 negative doc vectors. However, I cannot just iterate through that qrel inside the Dataset class. I get the following error and I’m stuck.


Exception: KeyError: Traceback (most recent call last):
  File "/home/ruchit/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 99, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/ruchit/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 99, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/ruchit/Desktop/NLP/project/my_dataset.py", line 24, in __getitem__
    query = self.query_vecs[self.qrel[key]]
KeyError: 0

It looks like even after I’ve defined the __iter__() method, it is still being replaced with an iterator over some list.
The problem seems to be with the DataLoader class. I can iterate through an object of class Dataset, but after doing

ds = Dataset(qrel,query_vecs, offsets, fname)
dataloader = data.DataLoader(ds, **params)

dataloader is no longer iterable.

For example, the code below works fine:

for sample in ds:
    print(sample)

But after passing it to the DataLoader, it no longer works. The following fails:

for sample in dataloader:
    print(sample)

Check whether the next(iter(x)) inside your __next__ is causing the problem, because essentially for each call to next you are re-initializing the iterator.
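A tiny sketch of that pitfall with a toy dictionary (hypothetical data, just to illustrate):

```python
# Each call to next(iter(x)) builds a brand-new iterator,
# so it always yields the FIRST element instead of advancing.
qrel = {"q1": ["d1"], "q2": ["d2"], "q3": ["d3"]}

first = next(iter(qrel))   # 'q1'
second = next(iter(qrel))  # still 'q1', not 'q2'
assert first == second == "q1"

# A correct __next__ would advance one stored iterator instead:
it = iter(qrel)
assert next(it) == "q1"
assert next(it) == "q2"
```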

At first I did not implement the __next__ method at all, but I was getting the same error. I thought implementing it would remove the error. But should I remove the __next__ method from my code entirely?

I did one hack here. Instead of iterating over the dictionary, I materialized its keys into a list and indexed into that, so I no longer have to implement the __iter__() pattern at all. It’s working now, but I get an IndexError when I run it with 8 workers. With one worker it works fine.
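The hack described above might look roughly like this (a sketch of my understanding, not the exact code; the names QrelDataset, qrel_keys, doc_keys, and _read_vec are my own, and np.fromstring is replaced by explicit parsing since it is deprecated):

```python
import random
import numpy as np
from torch.utils import data

class QrelDataset(data.Dataset):
    def __init__(self, qrel, query_vecs, offsets, fname):
        self.qrel = qrel
        self.query_vecs = query_vecs
        self.offsets = offsets
        self.f = open(fname, 'rt')
        # Materialize both key sets as lists so they can be indexed/sampled
        self.qrel_keys = list(qrel.keys())    # one entry per training query
        self.doc_keys = list(offsets.keys())  # all doc ids, for negative sampling

    def __len__(self):
        return len(self.qrel_keys)

    def _read_vec(self, offset):
        # Seek to the line for one document and parse its vector.
        # np.fromstring(..., sep=" ") is deprecated, so parse explicitly.
        self.f.seek(offset)
        return np.array(self.f.readline().split('\t')[1].split(),
                        dtype=np.float64)

    def __getitem__(self, idx):
        # The DataLoader passes an INTEGER index; translate it to a dict key
        key = self.qrel_keys[idx]
        query = self.query_vecs[key]
        pos = self._read_vec(self.offsets[self.qrel[key][0]])
        negs = [self._read_vec(self.offsets[k])
                for k in random.sample(self.doc_keys, 4)]
        return (query, pos, np.array(negs))
```

The key line is `key = self.qrel_keys[idx]`: the default sampler only ever produces integers 0..len(ds)-1, so the dataset has to translate them into dictionary keys itself.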

Have you tried it without the __iter__ and __next__ methods? Generally, you are only supposed to override __getitem__ (and __len__) for map-style datasets, as in the tutorial.
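For reference, a minimal map-style dataset following that pattern looks like this (toy data, just to show the shape of the API):

```python
from torch.utils import data

class ToyDataset(data.Dataset):
    """A map-style dataset: only __len__ and __getitem__ are overridden.
    The DataLoader drives iteration itself, drawing integer indices
    from its sampler and calling dataset[i] for each one."""
    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

loader = data.DataLoader(ToyDataset([10, 20, 30]), batch_size=1)
# Iterating the loader calls dataset[0], dataset[1], dataset[2]
seen = [int(batch[0]) for batch in loader]
```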

Yes, as I mentioned in the previous comment, I just iterated over the dictionary keys and removed the __iter__ and __next__ methods. That solved this problem. But after this, I am not able to set num_workers=8; it only works with num_workers=1. When I set it to 8, I get the IndexError. See this for details. I guess I should close this ticket if that’s possible; I’ll check whether I can do that.
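One common source of trouble with num_workers > 1 is state created in __init__ that does not survive being forked into worker processes, such as the file handle opened in the constructor here. A frequently used workaround is to open the file lazily inside __getitem__ so each worker gets its own handle. This is a sketch of that pattern (LazyFileDataset is my own name, and I can't confirm this is the cause of the IndexError above):

```python
from torch.utils import data

class LazyFileDataset(data.Dataset):
    """Open the backing file lazily, inside __getitem__, so each
    DataLoader worker process ends up with its own file handle
    instead of sharing one handle inherited from the parent."""
    def __init__(self, fname, offsets):
        self.fname = fname
        self.offsets = offsets   # byte offset of each line in the file
        self.f = None            # deliberately NOT opened here

    def __len__(self):
        return len(self.offsets)

    def __getitem__(self, idx):
        if self.f is None:       # first access in this process
            self.f = open(self.fname, 'rt')
        self.f.seek(self.offsets[idx])
        return self.f.readline().rstrip('\n')
```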