The input to my model is a variable-length sequence of images, so I store each sequence in a list.
My question: when the batch_size of my DataLoader is 8, can it return 8 lists of images (one list per sample)?
My dataset is defined as follows:
class mydataset(object):
    """Map-style dataset of variable-length image sequences, one label image each.

    Sample ``i`` (1-based) reads every file in ``<root><train_dir>/<i>/`` as its
    input sequence and ``<root><label_dir>/<i>.JPG`` as its label. When
    ``training`` is False the folder/label index is shifted by ``test_offset``.

    Because sequences differ in length, PyTorch's default collate function
    cannot batch them: it regroups the per-sample lists by frame position and,
    depending on the PyTorch version, truncates to the shortest sequence or
    raises. Pass ``collate_fn=mydataset.collate_fn`` to the ``DataLoader`` to
    receive one list of images per sample instead.
    """

    def __init__(self, root, train_dir, label_dir, img_num, transform,
                 training=True, test_offset=530):
        """Index the file paths of ``img_num`` samples.

        Args:
            root: Prefix string prepended verbatim to ``train_dir`` and
                ``label_dir`` (no separator is inserted between them).
            train_dir: Directory, relative to ``root``, holding one folder per
                sample named by its integer index.
            label_dir: Directory, relative to ``root``, holding ``<index>.JPG``
                label images.
            img_num: Number of samples to index.
            transform: Callable applied to every input image and to the label.
            training: If False, every sample index is shifted by ``test_offset``.
            test_offset: Index shift used when ``training`` is False.
        """
        super(mydataset, self).__init__()
        self.root = root
        self.transform = transform
        self.train_dir = self.root + train_dir
        self.label_dir = self.root + label_dir
        self.img_num = img_num
        self.img_paths = []
        offset = 0 if training else test_offset
        for i in range(1, img_num + 1):
            idx = i + offset
            seq_dir = self.train_dir + '/' + str(idx) + '/'
            # os.listdir order is arbitrary; sort so the frame order is
            # deterministic. The sort is lexicographic, so numeric file names
            # need zero-padding if numeric order matters.
            train_paths = [seq_dir + name for name in sorted(os.listdir(seq_dir))]
            self.img_paths.append({
                "train_img_path_list": train_paths,
                "label_img_path": self.label_dir + "/" + str(idx) + ".JPG",
            })

    def __len__(self):
        """Return the number of indexed samples (``img_num``)."""
        return self.img_num

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` with PIL and return an RGB copy, closing the file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        """Load sample ``index`` and apply ``transform`` with a shared seed.

        Returns:
            dict with ``"train_img_list"`` (list of transformed input images in
            sorted file-name order) and ``"label_img"`` (transformed label).
        """
        paths = self.img_paths[index]
        train_img_list = [self._load_rgb(p) for p in paths["train_img_path_list"]]
        label_img = self._load_rgb(paths["label_img_path"])
        # Re-seed before every transform call so random transforms make the
        # same choices for every frame and for the label.
        # NOTE(review): only Python's ``random`` is seeded here; if the
        # transforms draw from torch's RNG instead, torch must be seeded as
        # well for the frames to stay aligned — confirm for your transforms.
        seed = np.random.randint(0, 2 ** 16)
        for k, img in enumerate(train_img_list):
            random.seed(seed)
            train_img_list[k] = self.transform(img)
        random.seed(seed)
        label_img = self.transform(label_img)
        return {"train_img_list": train_img_list, "label_img": label_img}

    @staticmethod
    def collate_fn(batch):
        """Batch samples whose image sequences have different lengths.

        Usage: ``DataLoader(dataset, batch_size=8, collate_fn=mydataset.collate_fn)``.

        Returns:
            dict with ``"train_img_list"``, a list of ``len(batch)`` per-sample
            image lists left unbatched, and ``"label_img"``, the labels combined
            by PyTorch's default collate (a ``[B, C, H, W]`` tensor when every
            transformed label is an equally sized ``C x H x W`` tensor).
        """
        from torch.utils.data.dataloader import default_collate
        return {
            "train_img_list": [sample["train_img_list"] for sample in batch],
            "label_img": default_collate([sample["label_img"] for sample in batch]),
        }
Data fetching:
for j, dataa in enumerate(trainloader):
    # NOTE(review): with DataLoader's default collate_fn (per PyTorch's
    # default_collate), 'train_img_list' comes back regrouped by frame position
    # across the batch rather than as one list per sample. Variable-length
    # sequences therefore need a custom collate_fn that keeps per-sample lists.
    exposure_sequence_list, Label = dataa['train_img_list'], dataa['label_img']
When batch_size = 8, Label has shape [batch_size, channels, height, width],
and exposure_sequence_list is supposed to be a list of 8 lists (one per sample):
# Inspect the collated batch: expected a list holding batch_size (8) per-sample lists.
print(type(exposure_sequence_list))
print(len(exposure_sequence_list))
However, when I print them, the type is `list` but the length is 7 (the length of one of the sequences) instead of 8.