I think the best way to explain the DataLoader is to start with the Dataset.
Using a Dataset you can load and handle a dataset, as the name suggests.
Basically, you just need to define three functions in a Dataset.
In the __init__ method you can pass some arguments, e.g. a pre-loaded dataset or file paths.
You can also pass transformations to it, but let’s stay with the basics.
Let’s assume we have a very large dataset, so we pass file paths to the Dataset:
def __init__(self, image_paths):
    self.image_paths = image_paths
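As a side note on the transformations mentioned above, one common pattern is to accept an optional transform callable in __init__. The transform argument and the MyImageDataset name below are illustrative additions, not part of the snippet above:

```python
from torch.utils.data import Dataset

class MyImageDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        # store the file paths; images are loaded lazily later
        self.image_paths = image_paths
        # optional callable, e.g. a torchvision.transforms.Compose
        self.transform = transform
```

In __getitem__ you would then apply self.transform to the loaded image if it is not None.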
Next we define the __getitem__ function. This function receives an index argument and loads the sample at that index. Here we can use index to look up the file path and load the image lazily:
def __getitem__(self, index):
    image = PIL.Image.open(self.image_paths[index])
    # Transform it to a tensor
    x = torchvision.transforms.functional.to_tensor(image)
    return x
The last function is __len__, which just returns the length of the dataset:
def __len__(self):
    return len(self.image_paths)
It is used e.g. by the DataLoader to know how many samples are available in the Dataset.
Until now we didn’t have to think about batching; in other frameworks we would have to implement something like a for-loop in a generator to load batches of samples.
We also didn’t have to think about shuffling etc.
The complete Dataset would look like this:
import PIL.Image
import torchvision.transforms.functional
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, image_paths):
        self.image_paths = image_paths

    def __getitem__(self, index):
        image = PIL.Image.open(self.image_paths[index])
        # Transform it to a tensor
        x = torchvision.transforms.functional.to_tensor(image)
        return x

    def __len__(self):
        return len(self.image_paths)
A DataLoader basically wraps a Dataset and gives us some options, e.g. setting the batch size or activating shuffling. It’s also able to use multi-processing to speed up the loading:
loader = DataLoader(dataset,
                    batch_size=64,
                    num_workers=8,  # set the number of workers for multi-processing
                    shuffle=True)
So just by wrapping the Dataset we got “automatic” batching, multi-processing and shuffling.
The DataLoader has some additional arguments, which we will skip for now.
Now we can use this DataLoader in a for-loop. In every iteration the loader will return a whole batch, which it assembles by calling the Dataset’s __getitem__ once per sample and stacking the results. The number of iterations per epoch is determined by the __len__ of the Dataset and the batch size.
Since we just defined image paths without a target, we will just get the images back.
The last batch can be smaller than batch_size if the length of the dataset is not divisible by the batch size.
If you don’t want this behavior, you can drop the last (smaller) batch by passing drop_last=True to the DataLoader.
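To make the last-batch behavior concrete, here is a tiny sketch with a TensorDataset standing in for the image dataset above; the sizes (10 samples, batch_size=4) are made up for illustration:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(10, 3))  # 10 samples

# default: the last, incomplete batch is kept
loader = DataLoader(dataset, batch_size=4)
print([batch[0].shape[0] for batch in loader])  # [4, 4, 2]

# drop_last=True: the incomplete batch is dropped
loader = DataLoader(dataset, batch_size=4, drop_last=True)
print([batch[0].shape[0] for batch in loader])  # [4, 4]
```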
Since the DataLoader can load the samples in background worker processes (num_workers > 0), we can train the model while the next batch is being loaded.
for batch_idx, data in enumerate(loader):
    print('batch {}, shape {}'.format(batch_idx, data.shape))
    # Your training routine
I hope this helps!