KeyError when enumerating over dataloader - why?

I am writing a binary classification model that consists of audio files of 40 participants and classifies them according to whether they have a speech disorder or not. The audio files have been divided into 5 second segments and to avoid subject bias, I have split the training/testing/validation sets such that a subject only appears in one set (i.e. participant ID02 does not appear in both the training and testing sets). The following error appears when I attempt to enumerate over the DataLoader validLoader in the code below and I’m not entirely sure why this error is occurring. Does anyone have any advice?

KeyError                                  Traceback (most recent call last)
<ipython-input-69-55be99283cf7> in <module>()
----> 1 for i, data in enumerate(valid_loader, 0):
      2   images, labels = data
      3   print("Batch", i, "size:", len(images))

3 frames
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    361 
    362     def __next__(self):
--> 363         data = self._next_data()
    364         self._num_yielded += 1
    365         if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
    987             else:
    988                 del self._task_info[idx]
--> 989                 return self._process_data(data)
    990 
    991     def _try_put_index(self):

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data)
   1012         self._try_put_index()
   1013         if isinstance(data, ExceptionWrapper):
-> 1014             data.reraise()
   1015         return data
   1016 

/usr/local/lib/python3.6/dist-packages/torch/_utils.py in reraise(self)
    393             # (https://bugs.python.org/issue2651), so we work around it.
    394             msg = KeyErrorMessage(msg)
--> 395         raise self.exc_type(msg)

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 185, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "<ipython-input-44-245be0a1e978>", line 19, in __getitem__
    x = Image.open(self.df['path'][index])
  File "/usr/local/lib/python3.6/dist-packages/pandas/core/series.py", line 871, in __getitem__
    result = self.index.get_value(self, key)
  File "/usr/local/lib/python3.6/dist-packages/pandas/core/indexes/base.py", line 4405, in get_value
    return self._engine.get_value(s, k, tz=getattr(series.dtype, "tz", None))
  File "pandas/_libs/index.pyx", line 80, in pandas._libs.index.IndexEngine.get_value
  File "pandas/_libs/index.pyx", line 90, in pandas._libs.index.IndexEngine.get_value
  File "pandas/_libs/index.pyx", line 138, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/hashtable_class_helper.pxi", line 998, in pandas._libs.hashtable.Int64HashTable.get_item
  File "pandas/_libs/hashtable_class_helper.pxi", line 1005, in pandas._libs.hashtable.Int64HashTable.get_item
KeyError: 36

Can anyone advise why this is happening?

from google.colab import drive
drive.mount('/content/drive')

import torch
import torchvision
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import utils
from  torch.utils.data import Dataset

from sklearn.metrics import confusion_matrix
from skimage import io, transform, data
from skimage.color import rgb2gray

import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
import pandas as pd
import numpy as np
import csv
import os
import math
import cv2

root_dir = "/content/drive/My Drive/Read_Text/5_Second_Segments/"
class_names = [
  "Parkinsons_Disease",
  "Healthy_Control"
]

def get_meta(root_dir, dirs):
    """ Fetches the meta data for all the images and assigns labels.
    """
    paths, classes = [], []
    for i, dir_ in enumerate(dirs):
        for entry in os.scandir(root_dir + dir_):
            if (entry.is_file()):
                paths.append(entry.path)
                classes.append(i)
                
    return paths, classes


paths, classes = get_meta(root_dir, class_names)

data = {
    'path': paths,
    'class': classes
}

data_df = pd.DataFrame(data, columns=['path', 'class'])
data_df = data_df.sample(frac=1).reset_index(drop=True) # Shuffles the data

from pandas import option_context

print("Found", len(data_df), "images.")

with option_context('display.max_colwidth', 400):
    display(data_df.head(100))

class Audio(Dataset):

    def __init__(self, df, transform=None):
        """
        Args:
            image_dir (string): Directory with all the images
            df (DataFrame object): Dataframe containing the images, paths and classes
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        # Load image from path and get label
        x = Image.open(self.df['path'][index])
        try:
          x = x.convert('RGB') # To deal with some grayscale images in the data
        except:
          pass
        y = torch.tensor(int(self.df['class'][index]))

        if self.transform:
            x = self.transform(x)

        return x, y

def compute_img_mean_std(image_paths):
    """
        Author: @xinruizhuang. Computing the mean and std of three channel on the whole dataset,
        first we should normalize the image from 0-255 to 0-1
    """

    img_h, img_w = 224, 224
    imgs = []
    means, stdevs = [], []

    for i in tqdm(range(len(image_paths))):
        img = cv2.imread(image_paths[i])
        img = cv2.resize(img, (img_h, img_w))
        imgs.append(img)

    imgs = np.stack(imgs, axis=3)
    print(imgs.shape)

    imgs = imgs.astype(np.float32) / 255.

    for i in range(3):
        pixels = imgs[:, :, i, :].ravel()  # resize to one row
        means.append(np.mean(pixels))
        stdevs.append(np.std(pixels))

    means.reverse()  # BGR --> RGB
    stdevs.reverse()

    print("normMean = {}".format(means))
    print("normStd = {}".format(stdevs))
    return means, stdevs

norm_mean, norm_std = compute_img_mean_std(paths)

data_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])

unique_users = data_df['path'].str[-20:-16].unique()
train_users, test_users = np.split(np.random.permutation(unique_users), [int(0.8*len(unique_users))])
df_train = data_df[data_df['path'].str[-20:-16].isin(train_users)]
test_data_df = data_df[data_df['path'].str[-20:-16].isin(test_users)]

train_unique_users = df_train['path'].str[-20:-16].unique()
train_users, validate_users = np.split(np.random.permutation(train_unique_users), [int(0.875*len(train_unique_users))])
train_data_df = df_train[df_train['path'].str[-20:-16].isin(train_users)]
valid_data_df = df_train[df_train['path'].str[-20:-16].isin(validate_users)]

ins_dataset_train = Audio(
    df=train_data_df,
    transform=data_transform,
)

ins_dataset_valid = Audio(
    df=valid_data_df,
    transform=data_transform,
)

ins_dataset_test = Audio(
    df=test_data_df,
    transform=data_transform,
)

train_loader = torch.utils.data.DataLoader(
    ins_dataset_train,
    batch_size=8,
    shuffle=True,
    num_workers=2
)

test_loader = torch.utils.data.DataLoader(
    ins_dataset_test,
    batch_size=16,
    shuffle=True,
    num_workers=2
)

valid_loader = torch.utils.data.DataLoader(
    ins_dataset_valid,
    batch_size=16,
    shuffle=True,
    num_workers=2
)

//(This is where the error is occurring.)
for i, data in enumerate(valid_loader, 0):
  images, labels = data
  print("Batch", i, "size:", len(images))

Can you try replacing

train_data_df = df_train[df_train['path'].str[-20:-16].isin(train_users)]
valid_data_df = df_train[df_train['path'].str[-20:-16].isin(validate_users)]

with

train_data_df = df_train[df_train['path'].str[-20:-16].isin(train_users)].reset_index(drop=True)
valid_data_df = df_train[df_train['path'].str[-20:-16].isin(validate_users)].reset_index(drop=True)

Thank you! That’s solved my problem! :slight_smile: