Getting a KeyError while iterating through a dataset with DataLoader

Hi!

I have a DataFrame called df with columns named word and dict.

First, I split this DataFrame into train and test parts with train_test_split:

from sklearn.model_selection import train_test_split
train_df, test_df = train_test_split(df, test_size=0.2, shuffle=True)
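One detail here matters later: train_test_split returns DataFrames that keep df's original row labels instead of renumbering them. A minimal sketch on hypothetical toy data (the df contents and random_state=0 are assumptions, not taken from the question) illustrates this:

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical toy data using the question's column names
df = pd.DataFrame({"word": list("abcdefghij"),
                   "dict": [{"id": i} for i in range(10)]})

train_df, test_df = train_test_split(df, test_size=0.2, shuffle=True,
                                     random_state=0)

# Both parts keep df's original (now shuffled) row labels. Together they
# cover 0..9, but each part's labels are in general not 0..len-1.
print(train_df.index.tolist())
print(test_df.index.tolist())
```

Positions in train_df still run from 0 to len(train_df) - 1; only the labels are scattered.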

Then I created a custom dataset and passed it to a DataLoader, so that I can iterate through the data in the training function of my machine learning model.

from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):

    def __init__(self, df):
        # keep the two columns as pandas Series
        self.words = df['word']
        self.dicts = df['dict']

    def __len__(self):
        return len(self.words)

    def __getitem__(self, index):
        # called by the DataLoader once per sample index
        word_ = self.words[index]
        dict_ = self.dicts[index]

        return word_, dict_

train_dataset = MyDataset(train_df)
test_dataset = MyDataset(test_df)
train_loader = DataLoader(train_dataset, batch_size=10)
test_loader = DataLoader(test_dataset, batch_size=10)
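For reference, a DataLoader built like this (default shuffle=False) asks the dataset for positions 0 to len(dataset) - 1, grouped into batches. A small sketch using PyTorch's SequentialSampler and BatchSampler on a hypothetical 23-item range shows the index batches it produces; each of these numbers is then passed to dataset[idx], as in the fetch.py frame of the traceback.

```python
from torch.utils.data import BatchSampler, SequentialSampler

# Roughly what DataLoader(dataset, batch_size=10) does when shuffle=False:
# walk positions 0..len-1 in order and group them into batches of 10.
sampler = SequentialSampler(range(23))  # stand-in for a 23-row dataset
batches = list(BatchSampler(sampler, batch_size=10, drop_last=False))

print(batches[0])   # positions 0 through 9
print(batches[-1])  # [20, 21, 22]
```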

The problem is that when I iterate through the data, I keep getting KeyErrors. It seems as if the DataLoader uses the indexes of test_dataset when it goes through train_dataset, and vice versa. I searched for a solution to this problem but haven't found one anywhere.
Do you have an idea how I can fix it?
Thank you in advance for your help!

The train_loader loop looks like this (just for testing):

for x, y in train_loader:
    print(x, y)

The error I get is this:

KeyError                                  Traceback (most recent call last)
<ipython-input-93-dcdc8a11e4d8> in <module>
----> 1 for x, y in train_loader:
      2     print(x, y)

C:\Downloads\Anaconda\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
    433         if self._sampler_iter is None:
    434             self._reset()
--> 435         data = self._next_data()
    436         self._num_yielded += 1
    437         if self._dataset_kind == _DatasetKind.Iterable and \

C:\Downloads\Anaconda\lib\site-packages\torch\utils\data\dataloader.py in _next_data(self)
    473     def _next_data(self):
    474         index = self._next_index()  # may raise StopIteration
--> 475         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    476         if self._pin_memory:
    477             data = _utils.pin_memory.pin_memory(data)

C:\Downloads\Anaconda\lib\site-packages\torch\utils\data\_utils\fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

C:\Downloads\Anaconda\lib\site-packages\torch\utils\data\_utils\fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

C:\Downloads\Anaconda\lib\site-packages\pandas\core\frame.py in __getitem__(self, key)
   2798             if self.columns.nlevels > 1:
   2799                 return self._getitem_multilevel(key)
-> 2800             indexer = self.columns.get_loc(key)
   2801             if is_integer(indexer):
   2802                 indexer = [indexer]

C:\Downloads\Anaconda\lib\site-packages\pandas\core\indexes\base.py in get_loc(self, key, method, tolerance)
   2646                 return self._engine.get_loc(key)
   2647             except KeyError:
-> 2648                 return self._engine.get_loc(self._maybe_cast_indexer(key))
   2649         indexer = self.get_indexer([key], method=method, tolerance=tolerance)
   2650         if indexer.ndim > 1 or indexer.size > 1:

pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas\_libs\hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

pandas\_libs\hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

KeyError: 0
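KeyError: 0 is what pandas raises when 0 is looked up as an index label that does not exist. A minimal sketch with a hypothetical Series, whose labels 2, 3 and 4 mimic rows left over after a split, shows the difference between label lookup with [] and positional lookup with .iloc:

```python
import pandas as pd

# Hypothetical column that, like a split, kept its original labels 2, 3, 4
words = pd.Series(["c", "d", "e"], index=[2, 3, 4])

print(words.iloc[0])  # c  (position 0)
print(words[2])       # c  (label 2)
try:
    words[0]          # label lookup: no row is labelled 0
except KeyError as err:
    print("KeyError:", err)  # KeyError: 0
```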

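The problem is that self.words[index] looks up index as a row label, not as a position. train_test_split keeps the original row labels of df, so after shuffling, train_df and test_df each hold a scattered subset of df's labels. The DataLoader asks for positions 0 to len(dataset) - 1, and whenever one of those numbers is a label that ended up in the other split, pandas raises a KeyError. That is why it looks as if the DataLoader were using the other dataset's indexes. Use positional indexing with .iloc instead: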

word_ = self.words.iloc[index]
dict_ = self.dicts.iloc[index]
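Putting the answer together, a runnable sketch of the corrected dataset might look like the following. The toy df is hypothetical, and df.sample(frac=0.8, random_state=0) is a stand-in for train_test_split (like it, it returns rows that keep their original, shuffled labels), so the example needs only pandas and PyTorch.

```python
import pandas as pd
from torch.utils.data import DataLoader, Dataset


class MyDataset(Dataset):
    def __init__(self, df):
        self.words = df["word"]
        self.dicts = df["dict"]

    def __len__(self):
        return len(self.words)

    def __getitem__(self, index):
        # .iloc looks rows up by position (0..len-1), which is what the
        # DataLoader's sampler passes in, regardless of df's index labels.
        word_ = self.words.iloc[index]
        dict_ = self.dicts.iloc[index]
        return word_, dict_


# Hypothetical toy data; df.sample stands in for a shuffled train split
# and keeps the original (now shuffled) row labels.
df = pd.DataFrame({"word": list("abcdefghij"),
                   "dict": [{"id": i} for i in range(10)]})
train_df = df.sample(frac=0.8, random_state=0)

train_loader = DataLoader(MyDataset(train_df), batch_size=3)
seen = []
for x, y in train_loader:
    print(x, y)     # a batch of words and the collated dicts
    seen.extend(x)
```

Resetting the labels after the split, for example with train_df = train_df.reset_index(drop=True), is another way to make labels and positions coincide, but .iloc keeps the dataset correct whatever index the DataFrame happens to have.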

Thanks for your answer! It solved the problem.