Torch.utils.data.random_split using all RAM

I have a custom dataset with 170 million samples that is lazily loaded. When I call torch.utils.data.random_split to split it into train and test sets, it takes a very long time. I can see memory utilization go all the way to 100% and the OS starts swapping; I have 32 GB of physical memory. It is my understanding that torch.utils.data.random_split does not load the data into memory. So why is random_split using so much memory for this operation?

random_split creates Subsets, as seen here, which wrap the Dataset and call it with the passed indices, as seen here. I don't see where the Dataset would be cloned, unless the memory usage increase is not caused by random_split but by, e.g., the workers in the DataLoader.
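
A minimal sketch (with a hypothetical toy dataset) of what the returned objects look like; it shows that random_split only wraps the original dataset in Subset objects holding index lists, without copying any samples:

import torch
from torch.utils.data import Dataset, random_split

class ToyDataset(Dataset):          # hypothetical stand-in, not the actual dataset
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return idx * 2              # sample is "computed" lazily on access

ds = ToyDataset()
train, test = random_split(ds, [7, 3])

print(type(train))                  # <class 'torch.utils.data.dataset.Subset'>
print(train.dataset is ds)          # True: both splits reference the same dataset object
print(train.indices[:5])            # the permuted indices assigned to this split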

The dataset is not responsible for that. It only holds references to files and reads them line by line, and only when __getitem__ is called.

If I don't call random_split and just iterate over the dataset, I see constant memory utilization without any observable increase.
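
For context, a minimal sketch of the kind of lazy, file-backed dataset described above (the class name and file layout are hypothetical, not the actual code):

from torch.utils.data import Dataset

class LazyLineDataset(Dataset):
    """Hypothetical example: holds only a file path; each sample is read on demand."""
    def __init__(self, path, num_lines):
        self.path = path            # just a reference to the file, nothing is loaded here
        self.num_lines = num_lines

    def __len__(self):
        return self.num_lines

    def __getitem__(self, idx):
        # Scan to the requested line; slow, but memory stays roughly constant
        # because the dataset itself caches nothing.
        with open(self.path) as f:
            for i, line in enumerate(f):
                if i == idx:
                    return line.rstrip("\n")
        raise IndexError(idx)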

Could it be that just holding the list of 170 million indices is causing this memory spike?

I doubt it, as 170,000,000 indices would use ~1.2GB:

import sys
import torch

lengths = [100000000, 70000000]
indices = torch.randperm(sum(lengths)).tolist()

# Note: sys.getsizeof reports the size of the list object itself (its array of
# pointers), not the int objects it references, so the real footprint is somewhat larger.
print(sys.getsizeof(indices) / 1024**3)
# 1.266598753631115

If your memory usage is already high, this additional ~1.2 GB could indeed push your system into swap.
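
If the index list does turn out to be the problem, one lower-overhead option (a sketch, not a drop-in replacement; it assumes your dataset's __getitem__ accepts numpy integer indices, and `dataset` below stands for your existing lazily loaded dataset) is to build the Subsets yourself from a permutation kept as an int64 array, which stays at roughly 8 bytes per index:

import numpy as np
from torch.utils.data import Subset

n = 170_000_000                      # total number of samples (assumed)
train_len = 100_000_000              # assumed split sizes

perm = np.random.permutation(n)      # int64 array, ~8 bytes per index (~1.3 GB total)
train_ds = Subset(dataset, perm[:train_len])   # Subset only stores a reference to the array slice
test_ds = Subset(dataset, perm[train_len:])    # samples are still read lazily on access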