I have a custom, lazily loaded dataset with 170 million samples. When I call torch.utils.data.random_split to split it into train and test sets, it takes a very long time. I can see the memory utilization go all the way to 100% and the OS starts swapping; I have 32 GB of physical memory. It is my understanding that torch.utils.data.random_split does not load the data into memory. So why is random_split using so much memory for this operation?
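For context, the call in question looks roughly like this; a minimal sketch in which DummyDataset only stands in for the real lazily loaded dataset (the class name is illustrative, and the 100M/70M split sizes are taken from the numbers later in this thread, not from the poster's actual code):

import torch
from torch.utils.data import Dataset, random_split

class DummyDataset(Dataset):
    # stands in for the lazily loaded dataset; it stores no samples itself
    def __len__(self):
        return 170_000_000

    def __getitem__(self, idx):
        return idx

dataset = DummyDataset()
# this call is where the memory spike and slowdown are observed
train_set, test_set = random_split(dataset, [100_000_000, 70_000_000])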
random_split creates Subsets (as seen here), which wrap the Dataset and call into it with the passed indices (as seen here). I don't see where the Dataset would be cloned, unless the memory usage increase is not caused by random_split but e.g. by the workers in the DataLoader.
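To make that concrete, here is a small check with a toy dataset (ToyDataset is just for illustration): the Subsets returned by random_split only hold a reference to the original Dataset plus a list of indices, so no samples are copied.

import torch
from torch.utils.data import Dataset, random_split

class ToyDataset(Dataset):
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return idx * 2

dataset = ToyDataset(10)
train_set, test_set = random_split(dataset, [8, 2])

# both Subsets point at the same Dataset object; nothing is cloned
print(train_set.dataset is dataset, test_set.dataset is dataset)  # True True
print(train_set.indices)  # e.g. [3, 7, 0, 9, 5, 1, 8, 4]
print(train_set[0])       # forwards to dataset[train_set.indices[0]]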
The dataset is not responsible for that. It only holds references to files, and lines are read one at a time, and only when the __getitem__ method is called. If I don't call random_split and just iterate over the dataset, I see constant memory utilization without any observable increase.
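A rough sketch of the kind of dataset described above (this is not the poster's actual code; the single-file layout and the islice-based lookup are assumptions): it only keeps a file path, and a single line is read from disk when __getitem__ is called.

from itertools import islice
from torch.utils.data import Dataset

class LazyLineDataset(Dataset):
    # keeps only a file path; each sample is read from disk on demand
    def __init__(self, file_path, num_lines):
        self.file_path = file_path
        self.num_lines = num_lines

    def __len__(self):
        return self.num_lines

    def __getitem__(self, idx):
        # skip ahead to the idx-th line and return it; nothing is cached
        with open(self.file_path) as f:
            return next(islice(f, idx, idx + 1)).rstrip("\n")

Iterating over such a dataset touches one line at a time, which matches the flat memory usage described above.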
Can it be that just holding the list of the 170 M indices is causing this memory spike?
I doubt it, as 170,000,000 indices would use ~1.2 GB:
import sys
import torch

# the same permutation random_split builds internally before slicing it into Subsets
lengths = [100000000, 70000000]
indices = torch.randperm(sum(lengths)).tolist()
print(sys.getsizeof(indices) / 1024**3)
# 1.266598753631115
If your memory usage is already high, these additional 1.2 GB could indeed push your system into swap.
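For reference, that figure corresponds to the size of the list object itself on a 64-bit CPython build, where each list slot is an 8-byte pointer:

# 170,000,000 list slots * 8 bytes per pointer, converted to GiB
print(170_000_000 * 8 / 1024**3)  # ~1.2666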