class ImageDataset(data.Dataset):
    """Dataset returning ``num_augments`` transformed views of each image.

    Each item is a single tensor of shape ``[num_augments, C, H, W]`` —
    the same source image passed ``num_augments`` times through
    ``transform`` and stacked along a new leading dimension. Note that
    iterating the dataset directly yields ONE sample at a time; to get
    batches of shape ``[batch_size, num_augments, C, H, W]``, wrap the
    dataset in a ``torch.utils.data.DataLoader`` with ``batch_size`` set.
    """

    def __init__(self, root_dir, num_augments=2, transform=None):
        self.root_dir = root_dir
        # NOTE(review): [::600] keeps only every 600th directory entry —
        # confirm this aggressive subsampling is intentional.
        self.img_names = os.listdir(root_dir)[::600]
        self.num_augments = num_augments
        self.transform = transform

    def __getitem__(self, index):
        # os.path.join is portable; the original string '+' concatenation
        # was fragile (e.g. trailing separators in root_dir).
        path = os.path.join(self.root_dir, self.img_names[index])
        img = Image.open(path).convert('RGB')
        if self.transform is None:
            # The original built an empty list here and then crashed inside
            # torch.stack with a cryptic error; fail loudly instead.
            raise ValueError(
                "ImageDataset requires a transform that maps a PIL image "
                "to a tensor (e.g. torchvision.transforms.ToTensor())."
            )
        views = [self.transform(img) for _ in range(self.num_augments)]
        # dim= is the canonical torch keyword (axis= is a numpy alias).
        return torch.stack(views, dim=0)

    def __len__(self):
        return len(self.img_names)
I am iterating over my dataset directly, in the form:
for i, images in enumerate(train_dataset):
I expect `images` to be of shape [batch_size, num_augments, 3, height, width], but I get [1, num_augments, 3, height, width] regardless of my batch size. (Note: iterating a Dataset directly yields one sample at a time — batching is performed by `torch.utils.data.DataLoader`, so the dataset should be wrapped in a DataLoader with `batch_size` set before iterating.)