Dataset vs DataLoader

Consider the following code:

    import os
    import numpy as np
    from PIL import Image
    from torch.utils.data import Dataset, DataLoader

    class Cifar10Data(Dataset):
        def __init__(self, data_dir, data_size=0, transforms=None):
            files = os.listdir(data_dir)
            files = [os.path.join(data_dir, x) for x in files]
            if data_size < 0 or data_size > len(files):
                # a bare assert on a string is always truthy; raise instead
                raise ValueError("data_size should be between 0 and the number of files in the dataset")
            if data_size == 0:
                data_size = len(files)
            self.data_size = data_size
            # sample without replacement so no file appears twice
            self.files = np.random.choice(files, self.data_size, replace=False)
            self.transforms = transforms

        def __len__(self):
            return self.data_size

        def __getitem__(self, index):
            img_loc = self.files[index]
            img = np.asarray(Image.open(img_loc))  # load the image from disk (PIL assumed)
            img = preprocess(img)  # preprocess and label_mapping are assumed to be defined elsewhere
            # extract label from file name, e.g. "1234_cat.png" -> "cat"
            label_name = img_loc[:-4].split("_")[-1]
            label = label_mapping[label_name]
            img = img.astype(np.float32)
            if self.transforms:
                img = self.transforms(img)
            return img, label

    trainset = Cifar10Data(data_dir="cifar/train/", transforms=None)
    trainloader = DataLoader(trainset, batch_size=128, shuffle=True)

The DataLoader handles batching of the data (among other things) and expects a Dataset object as input.
Since I can already set batch_size to 128, the data_size attribute of the Cifar10Data class seems redundant to me, and I don't quite understand what it does. What is the data_size attribute for in this context?

It’s a parameter that selects how many samples are used to construct the dataset.
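As a minimal sketch of that idea (with hypothetical file names, and using the standard library's `random.sample` in place of `np.random.choice` to draw without replacement):

    import random

    files = [f"cifar/train/{i}_cat.png" for i in range(10)]  # hypothetical file list
    data_size = 4

    # keep only data_size of the available files for the dataset
    subset = random.sample(files, data_size)
    print(len(subset))  # 4

The rest of the files are simply never seen by the Dataset, regardless of what batch_size the DataLoader uses.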

I see. So data_size sets how many samples the dataset holds, while batch_size is independent of it; what data_size affects is the number of batches per epoch.
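The relationship can be spelled out with a quick calculation (the numbers here just mirror the example above):

    import math

    data_size = 500   # samples kept in the dataset (hypothetical)
    batch_size = 128  # samples the DataLoader yields per batch

    # batches per epoch: full batches plus one final partial batch
    num_batches = math.ceil(data_size / batch_size)
    print(num_batches)  # 4

By default the DataLoader yields that last partial batch of `data_size % batch_size` samples; passing `drop_last=True` discards it instead.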