KeyError when for loop dataloader - why?

when i try to use k-fold in pytorch:

for fold, (train_idx,val_idx) in enumerate(splits.split(np.arange(len(dataset)))):

    print('Fold {}'.format(fold + 1))

    train_sampler = SubsetRandomSampler(train_idx)
    test_sampler = SubsetRandomSampler(val_idx)
    
    train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler,num_workers=0)
    test_loader = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler,num_workers=0)
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    train_loader = DeviceDataLoader(train_loader, device)
    test_loader = DeviceDataLoader(test_loader, device)


    for batch in train_loader:
        pass

Error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
/tmp/ipykernel_34/2466581335.py in <module>
     16 
     17 
---> 18     for batch in train_loader:
     19         pass
     20 

/tmp/ipykernel_34/3232548952.py in __iter__(self)
     22     def __iter__(self):
     23         """ Yield a batch of data after moving it to device"""
---> 24         for b in self.dl:
     25             yield to_device(b,self.device)
     26 

/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    519             if self._sampler_iter is None:
    520                 self._reset()
--> 521             data = self._next_data()
    522             self._num_yielded += 1
    523             if self._dataset_kind == _DatasetKind.Iterable and \

/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    559     def _next_data(self):
    560         index = self._next_index()  # may raise StopIteration
--> 561         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    562         if self._pin_memory:
    563             data = _utils.pin_memory.pin_memory(data)

/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataset.py in __getitem__(self, idx)
    255         else:
    256             sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
--> 257         return self.datasets[dataset_idx][sample_idx]
    258 
    259     @property

KeyError: 0

And my dataset is:

train_data_df = DataFrame({"image": all_data["train"]["images"], "count": [item["lymphocyte"]for item in all_data["train"]["count"]]}).reset_index(drop=True)
dev_data_df = DataFrame({"image": all_data["dev"]["images"], "count": [item["lymphocyte"]for item in all_data["dev"]["count"]]}).reset_index(drop=True)

print(train_data_df)

class CONICDataset(Dataset): 
    def __init__(self, df, transform=None, target_transform=None): 
        self.df = df
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        img = self.df['image'][index]
        label = torch.tensor(self.df['count'][index])
        
        if self.transform is not None:
            img = self.transform(img)   

        return img, label

    def __len__(self):
        return len(self.df) 
 
train_data = CONICDataset(train_data_df, 
                    transform["train"])

val_data =  CONICDataset(dev_data_df,
                    transform["dev"])
dataset = ConcatDataset([train_data, dev_data])

Any thoughts on what could cause this? Or how I should approach this error?

I don’t know what DeviceDatLoader is so could you check if the code works fine without it?
If not, could you check, if dataset[0] returns a valid sample?