I have the following data-loading function:
```python
import torch
from torchvision import datasets, transforms

def load_dataset(size_batch, size):
    data_path = "/home/bledc/dataset/test_set/crops_BSD"
    transformations = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor()
    ])
    train_dataset = datasets.ImageFolder(
        root=data_path,
        transform=transformations
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=size_batch,
        shuffle=True,
        num_workers=0,
        drop_last=True
    )
    return train_loader
```
I iterate through it in my training loop with the following:
```python
data_loader = load_dataset(batch_size, width)
for data in data_loader:
    model.zero_grad()
    optimizer.zero_grad()
    img, _ = data
    img = img.to(device)
```
Can someone explain what the benefit would be of rewriting the `load_dataset()` function as a class? I ask because I have been using the template above for data loading, taken from code I found online, but most codebases seem to use `class LoaderName(Dataset)`, defining the initial state in `__init__` (with a call to `super()`) and implementing the required methods.
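For context, `ImageFolder` is itself a `Dataset` subclass, so your function already uses this pattern indirectly; writing your own subclass mainly buys you control over how individual samples are located, loaded, and transformed. A minimal sketch of the pattern (the name `CropDataset` and the in-memory tensor list are hypothetical, just to keep the example self-contained without files on disk):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class CropDataset(Dataset):
    """Hypothetical custom Dataset wrapping a list of image tensors."""

    def __init__(self, images):
        super().__init__()
        self.images = images  # any indexable collection of samples

    def __len__(self):
        # number of samples the DataLoader can draw from
        return len(self.images)

    def __getitem__(self, idx):
        # called once per sample; this is where you would load a file,
        # apply transforms, or return (image, label) pairs
        return self.images[idx]

# Usage: the DataLoader batches, shuffles, and iterates exactly as before.
images = [torch.randn(1, 64, 64) for _ in range(8)]
loader = DataLoader(CropDataset(images), batch_size=4, shuffle=True)
batch = next(iter(loader))
print(tuple(batch.shape))  # (4, 1, 64, 64)
```

The practical benefit is that `__getitem__` becomes the single place to customize per-sample logic (e.g. loading a noisy/clean image pair, or cropping on the fly), while batching and shuffling stay in `DataLoader`, unchanged.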