KeyError when enumerating the DataLoader after using GroupShuffleSplit to divide the training set into training and validation sets

I am trying to use a custom CNN to classify spectrogram images generated from 3 s audio segments. I am using GroupShuffleSplit to divide the training dataset into a training set and a validation set while ensuring that each participant appears in only one of them (to prevent data leakage). I am using a custom Dataset, as follows:

from torch.utils.data import DataLoader, Dataset, random_split
import torchaudio

# ----------------------------
# Sound Dataset
# ----------------------------
class SoundDS(Dataset):
  def __init__(self, df):
    self.df = df
    self.duration = 3000   # clip length in ms (3 s segments)
    self.sr = 44100        # target sample rate
    self.channel = 2       # target number of channels
    self.shift_pct = 0.4   # max time shift as a fraction of clip length
            
  # ----------------------------
  # Number of items in dataset
  # ----------------------------
  def __len__(self):
    return len(self.df)    
    
  # ----------------------------
  # Get i'th item in dataset
  # ----------------------------
  def __getitem__(self, idx):
    # Absolute file path of the audio file - concatenate the audio directory
    # with the relative path
    audio_file = self.df.loc[idx, "relative_path"]
    class_id = self.df.loc[idx, "dx"]
    # participant_id = self.df.loc[idx, 'adressfname']
    # file_name = self.df.loc[idx, 'file_name']

    aud = AudioUtil.open(audio_file)
    print("here")
    # Some sounds have a higher sample rate, or fewer channels, than the
    # majority, so standardise the number of channels and the sample rate.
    # Unless the sample rates match, pad_trunc will produce arrays of
    # different lengths even though the sound duration is the same.
    reaud = AudioUtil.resample(aud, self.sr)
    rechan = AudioUtil.rechannel(reaud, self.channel)

    dur_aud = AudioUtil.pad_trunc(rechan, self.duration)
    shift_aud = AudioUtil.time_shift(dur_aud, self.shift_pct)
    sgram = AudioUtil.spectro_gram(shift_aud, n_mels=64, n_fft=1024, hop_len=None)
    aug_sgram = AudioUtil.spectro_augment(sgram, max_mask_pct=0.1, n_freq_masks=2, n_time_masks=2)

    return aug_sgram, class_id

If I generate a SoundDS object from the training set, randomly split it into training and validation subsets, and pass these subsets into two respective data loaders, my model trains without any issues:

import torch
from torch.utils.data import random_split

myds = SoundDS(train_df)

# Random split of 80:20 between training and validation
num_items = len(myds)
num_train = round(num_items * 0.8)
num_val = num_items - num_train
train_ds, val_ds = random_split(myds, [num_train, num_val])

# Create training and validation data loaders
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=16, shuffle=True)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=16, shuffle=False)

However, if I first use GroupShuffleSplit on the dataframe train_df (grouping by the column adressfname), then build two SoundDS objects from the resulting dataframes and pass them into two respective data loaders, I encounter the error shown at the bottom of this post as soon as I run for i, data in enumerate(train_dl):

from sklearn.model_selection import GroupShuffleSplit

splitter = GroupShuffleSplit(test_size=0.15, n_splits=1, random_state=7)
split = splitter.split(train_df, groups=train_df['adressfname'])
train_inds, valid_inds = next(split)

train_data_df = train_df.iloc[train_inds]
valid_data_df = train_df.iloc[valid_inds]

train_dataset_DS = SoundDS(train_data_df)
valid_dataset_DS = SoundDS(valid_data_df)
train_dl = torch.utils.data.DataLoader(train_dataset_DS, batch_size=16, shuffle=True)
val_dl = torch.utils.data.DataLoader(valid_dataset_DS, batch_size=16, shuffle=False)
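
# Hypothetical sanity check (not in the original post): verify that the
# grouped split really keeps every participant in exactly one set.
train_ids = set(train_data_df["adressfname"])
valid_ids = set(valid_data_df["adressfname"])
assert train_ids.isdisjoint(valid_ids), "participant appears in both splits"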

def training(model, train_dl, num_epochs):
  # Loss Function, Optimizer and Scheduler
  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
  scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001,
                                                steps_per_epoch=int(len(train_dl)),
                                                epochs=num_epochs,
                                                anneal_strategy='linear')

  # Repeat for each epoch
  for epoch in range(num_epochs):
    running_loss = 0.0
    correct_prediction = 0
    total_prediction = 0

    # Repeat for each batch in the training set
    for i, data in enumerate(train_dl):   # <--- This is where the issue occurs
        # Get the input features and target labels, and put them on the GPU
        inputs, labels = data[0].to(device), data[1].to(device)

        # Normalize the inputs
        inputs_m, inputs_s = inputs.mean(), inputs.std()
        inputs = (inputs - inputs_m) / inputs_s

This produces the following traceback:
KeyError                                  Traceback (most recent call last)
/usr/local/lib/python3.8/dist-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
   3360             try:
-> 3361                 return self._engine.get_loc(casted_key)
   3362             except KeyError as err:

15 frames
/usr/local/lib/python3.8/dist-packages/pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

/usr/local/lib/python3.8/dist-packages/pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.Int64HashTable.get_item()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.Int64HashTable.get_item()

KeyError: 3170

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
<ipython-input-201-b60077dcefb4> in <module>
     54 
     55 num_epochs=2   # Just for demo, adjust this higher.
---> 56 training(myModel, train_dl, num_epochs)

<ipython-input-201-b60077dcefb4> in training(model, train_dl, num_epochs)
     15 
     16     # Repeat for each batch in the training set
---> 17     for i, data in enumerate(train_dl):
     18         # Get the input features and target labels, and put them on the GPU
     19         inputs, labels = data[0].to(device), data[1].to(device)

/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    626                 # TODO(https://github.com/pytorch/pytorch/issues/76750)
    627                 self._reset()  # type: ignore[call-arg]
--> 628             data = self._next_data()
    629             self._num_yielded += 1
    630             if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
    669     def _next_data(self):
    670         index = self._next_index()  # may raise StopIteration
--> 671         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    672         if self._pin_memory:
    673             data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     56                 data = self.dataset.__getitems__(possibly_batched_index)
     57             else:
---> 58                 data = [self.dataset[idx] for idx in possibly_batched_index]
     59         else:
     60             data = self.dataset[possibly_batched_index]

/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     56                 data = self.dataset.__getitems__(possibly_batched_index)
     57             else:
---> 58                 data = [self.dataset[idx] for idx in possibly_batched_index]
     59         else:
     60             data = self.dataset[possibly_batched_index]

<ipython-input-191-4c84224a9983> in __getitem__(self, idx)
     27 
     28     # print(self.df.loc[idx, 'relative_path'])
---> 29     audio_file = self.df.loc[idx, "relative_path"]
     30     class_id = self.df.loc[idx, "dx"]
     31     # participant_id = self.df.loc[idx, 'adressfname']

/usr/local/lib/python3.8/dist-packages/pandas/core/indexing.py in __getitem__(self, key)
    923                 with suppress(KeyError, IndexError):
    924                     return self.obj._get_value(*key, takeable=self._takeable)
--> 925             return self._getitem_tuple(key)
    926         else:
    927             # we by definition only have the 0th axis

/usr/local/lib/python3.8/dist-packages/pandas/core/indexing.py in _getitem_tuple(self, tup)
   1098     def _getitem_tuple(self, tup: tuple):
   1099         with suppress(IndexingError):
-> 1100             return self._getitem_lowerdim(tup)
   1101 
   1102         # no multi-index, so validate all of the indexers

/usr/local/lib/python3.8/dist-packages/pandas/core/indexing.py in _getitem_lowerdim(self, tup)
    836                 # We don't need to check for tuples here because those are
    837                 #  caught by the _is_nested_tuple_indexer check above.
--> 838                 section = self._getitem_axis(key, axis=i)
    839 
    840                 # We should never have a scalar section here, because

/usr/local/lib/python3.8/dist-packages/pandas/core/indexing.py in _getitem_axis(self, key, axis)
   1162         # fall thru to straight lookup
   1163         self._validate_key(key, axis)
-> 1164         return self._get_label(key, axis=axis)
   1165 
   1166     def _get_slice_axis(self, slice_obj: slice, axis: int):

/usr/local/lib/python3.8/dist-packages/pandas/core/indexing.py in _get_label(self, label, axis)
   1111     def _get_label(self, label, axis: int):
   1112         # GH#5667 this will fail if the label is not present in the axis.
-> 1113         return self.obj.xs(label, axis=axis)
   1114 
   1115     def _handle_lowerdim_multi_index_axis0(self, tup: tuple):

/usr/local/lib/python3.8/dist-packages/pandas/core/generic.py in xs(self, key, axis, level, drop_level)
   3774                 raise TypeError(f"Expected label or tuple of labels, got {key}") from e
   3775         else:
-> 3776             loc = index.get_loc(key)
   3777 
   3778             if isinstance(loc, np.ndarray):

/usr/local/lib/python3.8/dist-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
   3361                 return self._engine.get_loc(casted_key)
   3362             except KeyError as err:
-> 3363                 raise KeyError(key) from err
   3364 
   3365         if is_scalar(key) and isna(key) and not self.hasnans:

KeyError: 3170

Does anyone have any idea why one approach works without any issues, but the other does not?

The df.loc usage is wrong: .loc looks rows up by the DataFrame's index labels, while the Dataset's __getitem__ receives the linear, positional index (0 to len-1) generated by the DataLoader. The random_split version presumably only works because train_df has a default RangeIndex, so labels and positions happen to coincide; the subsets produced with iloc after GroupShuffleSplit keep their original index labels, so positions and labels no longer line up (hence KeyError: 3170).
This code snippet shows the issue:

import pandas as pd

mydict = [{'a': 1, 'b': 2, 'c': 3, 'd': 4},
          {'a': 100, 'b': 200, 'c': 300, 'd': 400},
          {'a': 1000, 'b': 2000, 'c': 3000, 'd': 4000}]
df = pd.DataFrame(mydict)

subset = df.iloc[[1, 2]]   # keeps the original index labels 1 and 2

# works: index labels 1 and 2 exist in both frames
print(subset.loc[1])
print(df.loc[1])
print(subset.loc[2])
print(df.loc[2])

# works: iloc is purely positional, so row 0 of each frame is returned
print(subset.iloc[0])
print(df.iloc[0])

# works on the full frame, where label 0 still exists
print(df.loc[0])

# breaks: label 0 was not carried over into the subset
print(subset.loc[0])   # KeyError: 0

As you can see, subset.loc depends on the actual index labels, while iloc is purely positional. Change the indexing in __getitem__ to iloc (or reset the index of the subset dataframes) and see if this fixes the error.
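
A minimal sketch of that fix applied to the question's __getitem__ (only the lookup changes; row is just a local variable name introduced here, and the rest of the method stays as posted):

  def __getitem__(self, idx):
    # iloc is positional, so it matches the 0..len-1 indices the DataLoader passes in
    row = self.df.iloc[idx]
    audio_file = row["relative_path"]
    class_id = row["dx"]
    ...

Alternatively, keep the .loc lookups and realign the labels once, right after splitting:

train_data_df = train_df.iloc[train_inds].reset_index(drop=True)
valid_data_df = train_df.iloc[valid_inds].reset_index(drop=True)

Either way, the positional indices the DataLoader generates will again match what the DataFrame lookup expects.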