I created a custom dataset and a sampler and when I created created a loader using the instance of the sampler the indexing messed up after a while.
class CIFAR10_custom(Dataset):
def __init__(self, transformations, path):
super().__init__()
self.cifar = trimmed_CIFAR10(n=80)
self.transformations = transformations
labels_file = 'pair_labels.npy'
query_imgs_folder = 'query_imgs'
key_imgs_folder = 'key_imgs'
self.path = path
with open(os.path.join(self.path, labels_file), 'rb') as f:
np_arr = np.load(f)
self.boolean_matrix = torch.tensor(np_arr)
self.all_triplets = []
self.totensor = transforms.ToTensor()
for qimg_path in tqdm(glob(os.path.join(self.path, query_imgs_folder)+'/*.jpg')):
stem = Path(qimg_path).stem
_, idx_i, idx_j = stem.split("_")
label = self.boolean_matrix[int(idx_i), int(idx_j)]
kimg_path = os.path.join(os.path.join(self.path, key_imgs_folder), f'{stem}.jpg')
if os.path.isfile(qimg_path) and os.path.isfile(kimg_path):
self.all_triplets.append(
(qimg_path, kimg_path, label.item())
)
def __len__(self):
return len(self.all_triplets)
def __getitem__(self, idx):
print("**************")
print(idx)
print("**************")
q_img, k_img, label = self.all_triplets[idx]
q_img, k_img = Image.open(q_img), Image.open(k_img)
if transforms:
q_img, k_img = self.transformations(q_img), self.transformations(k_img)
else:
q_img, k_img = self.totensor(q_img), self.totensor(k_img)
return q_img, k_img, label
class CIFAR_sampler(Sampler):
def __init__(self, data_source, batch_size, drop_last=False):
super().__init__(data_source=data_source)
self.dataset = data_source
self.batch_size = batch_size
self.drop_last = drop_last
self.batch = []
self.batch_count = 0
self.buffer_of_neg = []
self.queue_of_pos = []
def __iter__(self):
for img_q, img_k, lbl in self.dataset:
if lbl:
self.queue_of_pos.append((img_q, img_k, lbl))
elif (lbl == 0 and len(self.buffer_of_neg) < 6*self.batch_size):
self.buffer_of_neg.append((img_q, img_k, lbl))
if len(self.queue_of_pos) and len(self.buffer_of_neg) > self.batch_size:
self.batch = self.buffer_of_neg[:self.batch_size-1]
for i in range(self.batch_size - 1):
self.buffer_of_neg.pop(0)
self.batch.insert(0, self.queue_of_pos.pop(0))
self.batch_count += 1
yield self.batch
self.batch = []
else:
self.batch.append((img_q, img_k, lbl))
if len(self.batch) == self.batch_size:
self.batch_count += 1
yield self.batch
self.batch = []
def __len__(self):
return self.batch_count
and the output is:
**************
0
**************
**************
1
**************
**************
2
**************
**************
3
**************
**************
4
**************
**************
5
**************
**************
6
**************
**************
7
**************
**************
8
**************
**************
9
**************
**************
10
**************
**************
11
**************
**************
12
**************
**************
13
**************
**************
14
**************
**************
15
**************
**************
16
**************
**************
17
**************
**************
18
**************
**************
19
**************
**************
20
**************
**************
21
**************
**************
22
**************
**************
23
**************
**************
24
**************
**************
25
**************
**************
26
**************
**************
27
**************
**************
28
**************
**************
29
**************
**************
30
**************
**************
31
**************
**************
[(tensor([[[0.1961, 0.1961, 0.1961, ..., 0.1922, 0.1922, 0.1922],
[0.1961, 0.1961, 0.1961, ..., 0.1922, 0.1922, 0.1922],
[0.1961, 0.1961, 0.1961, ..., 0.1961, 0.1961, 0.1961],
...,
[0.1961, 0.1961, 0.1922, ..., 0.1882, 0.1882, 0.1882],
[0.1961, 0.1961, 0.1922, ..., 0.1843, 0.1843, 0.1843],
[0.1961, 0.1961, 0.1922, ..., 0.1843, 0.1843, 0.1843]],
[[0.1765, 0.1765, 0.1765, ..., 0.1765, 0.1765, 0.1765],
[0.1765, 0.1765, 0.1765, ..., 0.1765, 0.1765, 0.1765],
[0.1765, 0.1765, 0.1765, ..., 0.1765, 0.1765, 0.1765],
...,
[0.1765, 0.1765, 0.1765, ..., 0.1686, 0.1686, 0.1686],
[0.1765, 0.1765, 0.1765, ..., 0.1725, 0.1686, 0.1686],
[0.1725, 0.1725, 0.1725, ..., 0.1686, 0.1686, 0.1686]],
[[0.2039, 0.2039, 0.2039, ..., 0.2000, 0.2000, 0.2000],
[0.2039, 0.2039, 0.2039, ..., 0.2000, 0.2000, 0.2000],
[0.2039, 0.2039, 0.2039, ..., 0.2000, 0.2000, 0.2000],
...,
[0.2000, 0.2000, 0.2000, ..., 0.1961, 0.1961, 0.1961],
[0.2000, 0.2000, 0.2000, ..., 0.1922, 0.1922, 0.1922],
[0.2000, 0.2000, 0.1961, ..., 0.1922, 0.1922, 0.1922]]]), tensor([[[0.8353, 0.6745, 0.4706, ..., 0.2980, 0.3020, 0.3059],
[0.6902, 0.5098, 0.3451, ..., 0.2784, 0.2824, 0.2706],
[0.5294, 0.3647, 0.2549, ..., 0.2784, 0.2784, 0.2471],
...,
[0.6039, 0.5843, 0.5451, ..., 0.8980, 0.9020, 0.8980],
[0.7059, 0.7255, 0.6745, ..., 0.9176, 0.8902, 0.8549],
[0.7608, 0.8000, 0.7608, ..., 0.8824, 0.8471, 0.8157]],
[[0.8118, 0.6549, 0.4510, ..., 0.2510, 0.2471, 0.2392],
[0.6627, 0.4863, 0.3216, ..., 0.2314, 0.2275, 0.2118],
[0.4941, 0.3373, 0.2353, ..., 0.2314, 0.2235, 0.1882],
...,
[0.3647, 0.3412, 0.2980, ..., 0.8549, 0.8627, 0.8667],
[0.3647, 0.3843, 0.3490, ..., 0.8745, 0.8549, 0.8235],
[0.3294, 0.3804, 0.3647, ..., 0.8392, 0.8078, 0.7843]],
[[0.8196, 0.6745, 0.4745, ..., 0.2235, 0.2157, 0.2078],
[0.6588, 0.4941, 0.3373, ..., 0.2157, 0.2078, 0.1843],
[0.4863, 0.3373, 0.2392, ..., 0.2235, 0.2118, 0.1686],
...,
[0.3725, 0.3451, 0.2980, ..., 0.9137, 0.9216, 0.9176],
[0.3961, 0.4157, 0.3765, ..., 0.9412, 0.9098, 0.8745],
[0.3804, 0.4314, 0.4118, ..., 0.9059, 0.8706, 0.8353]]]), 0.0), (tensor([[[0.1608, 0.1608, 0.1647, ..., 0.1647, 0.1686, 0.1686],
[0.1333, 0.1373, 0.1451, ..., 0.1725, 0.1765, 0.1804],
[0.1216, 0.1255, 0.1333, ..., 0.1804, 0.1804, 0.1804],
..
..
..
..
..
[[0.5373, 0.5412, 0.5412, ..., 0.5059, 0.5059, 0.5059],
[0.5333, 0.5373, 0.5412, ..., 0.5059, 0.5059, 0.5059],
[0.5255, 0.5294, 0.5412, ..., 0.5098, 0.5059, 0.4980],
...,
[0.1529, 0.1451, 0.1255, ..., 0.2275, 0.2196, 0.2157],
[0.2039, 0.1686, 0.1098, ..., 0.2353, 0.2078, 0.1961],
[0.2275, 0.1804, 0.0980, ..., 0.2392, 0.2039, 0.1804]]]), 0.0)]
**************
Traceback (most recent call last):
File "/Users/jayeshvasudeva/Development/for_fun/SPICE_implementatioon/dataset.py", line 143, in <module>
for q, k, lbl in loader:
File "/Users/jayeshvasudeva/miniconda3/envs/SPICE/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 652, in __next__
data = self._next_data()
File "/Users/jayeshvasudeva/miniconda3/envs/SPICE/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 692, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "/Users/jayeshvasudeva/miniconda3/envs/SPICE/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/Users/jayeshvasudeva/miniconda3/envs/SPICE/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/Users/jayeshvasudeva/Development/for_fun/SPICE_implementatioon/dataset.py", line 76, in __getitem__
q_img, k_img, label = self.all_triplets[idx]
TypeError: list indices must be integers or slices, not list
Why is the last index a tuple and not an integer like rest of the cases?