Created a custom dataset and sampler, came across an indexing error

I created a custom dataset and a sampler and when I created created a loader using the instance of the sampler the indexing messed up after a while.

class CIFAR10_custom(Dataset):

    def __init__(self, transformations, path):
        super().__init__()
        self.cifar = trimmed_CIFAR10(n=80)
        self.transformations = transformations
        labels_file = 'pair_labels.npy'
        query_imgs_folder = 'query_imgs'
        key_imgs_folder = 'key_imgs'
        self.path = path
        with open(os.path.join(self.path, labels_file), 'rb') as f:
            np_arr = np.load(f)
        self.boolean_matrix = torch.tensor(np_arr)
        self.all_triplets = []
        self.totensor = transforms.ToTensor()
        for qimg_path in tqdm(glob(os.path.join(self.path, query_imgs_folder)+'/*.jpg')):
            stem = Path(qimg_path).stem
            _, idx_i, idx_j = stem.split("_")
            label = self.boolean_matrix[int(idx_i), int(idx_j)]
            kimg_path = os.path.join(os.path.join(self.path, key_imgs_folder), f'{stem}.jpg')
            if os.path.isfile(qimg_path) and os.path.isfile(kimg_path):
                self.all_triplets.append(
                    (qimg_path, kimg_path, label.item())
                )

    def __len__(self):
        return len(self.all_triplets)

    def __getitem__(self, idx):
        print("**************")
        print(idx)
        print("**************")
        q_img, k_img, label = self.all_triplets[idx]
        q_img, k_img = Image.open(q_img), Image.open(k_img)
        if transforms:
            q_img, k_img = self.transformations(q_img), self.transformations(k_img)
        else:
            q_img, k_img = self.totensor(q_img), self.totensor(k_img)
        return q_img, k_img, label
class CIFAR_sampler(Sampler):

    def __init__(self, data_source, batch_size, drop_last=False):
        super().__init__(data_source=data_source)
        self.dataset = data_source
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.batch = []
        self.batch_count = 0
        self.buffer_of_neg = []
        self.queue_of_pos = []

    def __iter__(self):

        for img_q, img_k, lbl in self.dataset:
            if lbl:
                self.queue_of_pos.append((img_q, img_k, lbl))

            elif (lbl == 0 and len(self.buffer_of_neg) < 6*self.batch_size):
                self.buffer_of_neg.append((img_q, img_k, lbl))

            if len(self.queue_of_pos) and len(self.buffer_of_neg) > self.batch_size:
                self.batch = self.buffer_of_neg[:self.batch_size-1]
                for i in range(self.batch_size - 1):
                    self.buffer_of_neg.pop(0)
                self.batch.insert(0, self.queue_of_pos.pop(0))
                self.batch_count += 1
                yield self.batch
                self.batch = []

            else:
                self.batch.append((img_q, img_k, lbl))
                if len(self.batch) == self.batch_size:
                    self.batch_count += 1
                    yield self.batch
                    self.batch = []

    def __len__(self):
        return self.batch_count

and the output is:

**************
0
**************
**************
1
**************
**************
2
**************
**************
3
**************
**************
4
**************
**************
5
**************
**************
6
**************
**************
7
**************
**************
8
**************
**************
9
**************
**************
10
**************
**************
11
**************
**************
12
**************
**************
13
**************
**************
14
**************
**************
15
**************
**************
16
**************
**************
17
**************
**************
18
**************
**************
19
**************
**************
20
**************
**************
21
**************
**************
22
**************
**************
23
**************
**************
24
**************
**************
25
**************
**************
26
**************
**************
27
**************
**************
28
**************
**************
29
**************
**************
30
**************
**************
31
**************
**************
[(tensor([[[0.1961, 0.1961, 0.1961,  ..., 0.1922, 0.1922, 0.1922],
         [0.1961, 0.1961, 0.1961,  ..., 0.1922, 0.1922, 0.1922],
         [0.1961, 0.1961, 0.1961,  ..., 0.1961, 0.1961, 0.1961],
         ...,
         [0.1961, 0.1961, 0.1922,  ..., 0.1882, 0.1882, 0.1882],
         [0.1961, 0.1961, 0.1922,  ..., 0.1843, 0.1843, 0.1843],
         [0.1961, 0.1961, 0.1922,  ..., 0.1843, 0.1843, 0.1843]],

        [[0.1765, 0.1765, 0.1765,  ..., 0.1765, 0.1765, 0.1765],
         [0.1765, 0.1765, 0.1765,  ..., 0.1765, 0.1765, 0.1765],
         [0.1765, 0.1765, 0.1765,  ..., 0.1765, 0.1765, 0.1765],
         ...,
         [0.1765, 0.1765, 0.1765,  ..., 0.1686, 0.1686, 0.1686],
         [0.1765, 0.1765, 0.1765,  ..., 0.1725, 0.1686, 0.1686],
         [0.1725, 0.1725, 0.1725,  ..., 0.1686, 0.1686, 0.1686]],

        [[0.2039, 0.2039, 0.2039,  ..., 0.2000, 0.2000, 0.2000],
         [0.2039, 0.2039, 0.2039,  ..., 0.2000, 0.2000, 0.2000],
         [0.2039, 0.2039, 0.2039,  ..., 0.2000, 0.2000, 0.2000],
         ...,
         [0.2000, 0.2000, 0.2000,  ..., 0.1961, 0.1961, 0.1961],
         [0.2000, 0.2000, 0.2000,  ..., 0.1922, 0.1922, 0.1922],
         [0.2000, 0.2000, 0.1961,  ..., 0.1922, 0.1922, 0.1922]]]), tensor([[[0.8353, 0.6745, 0.4706,  ..., 0.2980, 0.3020, 0.3059],
         [0.6902, 0.5098, 0.3451,  ..., 0.2784, 0.2824, 0.2706],
         [0.5294, 0.3647, 0.2549,  ..., 0.2784, 0.2784, 0.2471],
         ...,
         [0.6039, 0.5843, 0.5451,  ..., 0.8980, 0.9020, 0.8980],
         [0.7059, 0.7255, 0.6745,  ..., 0.9176, 0.8902, 0.8549],
         [0.7608, 0.8000, 0.7608,  ..., 0.8824, 0.8471, 0.8157]],

        [[0.8118, 0.6549, 0.4510,  ..., 0.2510, 0.2471, 0.2392],
         [0.6627, 0.4863, 0.3216,  ..., 0.2314, 0.2275, 0.2118],
         [0.4941, 0.3373, 0.2353,  ..., 0.2314, 0.2235, 0.1882],
         ...,
         [0.3647, 0.3412, 0.2980,  ..., 0.8549, 0.8627, 0.8667],
         [0.3647, 0.3843, 0.3490,  ..., 0.8745, 0.8549, 0.8235],
         [0.3294, 0.3804, 0.3647,  ..., 0.8392, 0.8078, 0.7843]],

        [[0.8196, 0.6745, 0.4745,  ..., 0.2235, 0.2157, 0.2078],
         [0.6588, 0.4941, 0.3373,  ..., 0.2157, 0.2078, 0.1843],
         [0.4863, 0.3373, 0.2392,  ..., 0.2235, 0.2118, 0.1686],
         ...,
         [0.3725, 0.3451, 0.2980,  ..., 0.9137, 0.9216, 0.9176],
         [0.3961, 0.4157, 0.3765,  ..., 0.9412, 0.9098, 0.8745],
         [0.3804, 0.4314, 0.4118,  ..., 0.9059, 0.8706, 0.8353]]]), 0.0), (tensor([[[0.1608, 0.1608, 0.1647,  ..., 0.1647, 0.1686, 0.1686],
         [0.1333, 0.1373, 0.1451,  ..., 0.1725, 0.1765, 0.1804],
         [0.1216, 0.1255, 0.1333,  ..., 0.1804, 0.1804, 0.1804],
..
..
..
..
..
       [[0.5373, 0.5412, 0.5412,  ..., 0.5059, 0.5059, 0.5059],
         [0.5333, 0.5373, 0.5412,  ..., 0.5059, 0.5059, 0.5059],
         [0.5255, 0.5294, 0.5412,  ..., 0.5098, 0.5059, 0.4980],
         ...,
         [0.1529, 0.1451, 0.1255,  ..., 0.2275, 0.2196, 0.2157],
         [0.2039, 0.1686, 0.1098,  ..., 0.2353, 0.2078, 0.1961],
         [0.2275, 0.1804, 0.0980,  ..., 0.2392, 0.2039, 0.1804]]]), 0.0)]
**************
Traceback (most recent call last):
  File "/Users/jayeshvasudeva/Development/for_fun/SPICE_implementatioon/dataset.py", line 143, in <module>
    for q, k, lbl in loader:
  File "/Users/jayeshvasudeva/miniconda3/envs/SPICE/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 652, in __next__
    data = self._next_data()
  File "/Users/jayeshvasudeva/miniconda3/envs/SPICE/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 692, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/Users/jayeshvasudeva/miniconda3/envs/SPICE/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/Users/jayeshvasudeva/miniconda3/envs/SPICE/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/Users/jayeshvasudeva/Development/for_fun/SPICE_implementatioon/dataset.py", line 76, in __getitem__
    q_img, k_img, label = self.all_triplets[idx]
TypeError: list indices must be integers or slices, not list

Why is the last index a tuple and not an integer like rest of the cases?

I don’t fully understand your custom sampler implementation as the sampler is responsible to create the indices passed to Dataset.__getitem__ while it seems you are directly iterating the dataset inside the sampler.
This would probably explain that you are passing the real samples into the __getitem__ instead of indices.
If you want to process the loaded samples, you might want to use a custom collate_fn instead.

I changed the code, as you pointed out, I was passing the data points itself, rather than the indices. But still it returns a list of indices rather a single index

class CIFAR_sampler(Sampler):

    def __init__(self, data_source, batch_size, drop_last=False):
        super().__init__(data_source=data_source)
        self.dataset = data_source
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.batch = []
        self.batch_count = 0
        self.buffer_of_neg = []
        self.queue_of_pos = []

    def __iter__(self):

        for idx, (img_q, img_k, lbl) in enumerate(self.dataset):
            if lbl:
                # self.queue_of_pos.append((img_q, img_k, lbl))
                self.queue_of_pos.append(idx)

            elif (lbl == 0 and len(self.buffer_of_neg) < 6*self.batch_size):
                # self.buffer_of_neg.append((img_q, img_k, lbl))
                self.buffer_of_neg.append(idx)

            if len(self.queue_of_pos) and len(self.buffer_of_neg) > self.batch_size:
                self.batch = self.buffer_of_neg[:self.batch_size-1]
                for i in range(self.batch_size - 1):
                    self.buffer_of_neg.pop(0)
                self.batch.insert(0, self.queue_of_pos.pop(0))
                self.batch_count += 1
                yield self.batch
                self.batch = []

            # else:
            #     self.batch.append((img_q, img_k, lbl))
            #     self.batch.append(idx)
            #     if len(self.batch) == self.batch_size:
            #         self.batch_count += 1
            #         yield self.batch
            #         self.batch = []

    def __len__(self):
        return self.batch_count

Output:

**************
0
**************
**************
1
**************
**************
2
**************
**************
3
**************
**************
4
**************
**************
5
**************
**************
6
**************
**************
7
**************
**************
8
**************
**************
9
**************
**************
10
**************
**************
11
**************
**************
12
**************
**************
13
**************
**************
14
**************
**************
15
**************
**************
16
**************
**************
17
**************
**************
18
**************
**************
19
**************
**************
20
**************
**************
21
**************
**************
22
**************
**************
23
**************
**************
24
**************
**************
25
**************
**************
26
**************
**************
27
**************
**************
28
**************
**************
29
**************
**************
30
**************
**************
31
**************
**************
32
**************
**************
33
**************
**************
34
**************
**************
35
**************
**************
36
**************
**************
37
**************
**************
[9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 29, 31, 32, 33, 34]
**************
Traceback (most recent call last):
  File "/Users/jayeshvasudeva/Development/for_fun/SPICE_implementatioon/dataset.py", line 147, in <module>
    for q, k, lbl in loader:
  File "/Users/jayeshvasudeva/miniconda3/envs/SPICE/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 652, in __next__
    data = self._next_data()
  File "/Users/jayeshvasudeva/miniconda3/envs/SPICE/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 692, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/Users/jayeshvasudeva/miniconda3/envs/SPICE/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/Users/jayeshvasudeva/miniconda3/envs/SPICE/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/Users/jayeshvasudeva/Development/for_fun/SPICE_implementatioon/dataset.py", line 76, in __getitem__
    q_img, k_img, label = self.all_triplets[idx]
TypeError: list indices must be integers or slices, not list

You are appending all indices (from the dataset loop) and are passing them to the __getitem__ via yield self.batch.
Take a look at some sampler implementations here to see how they are passing the indices to the Dataset.