Why does my datapipe based dataloader take up so much memory?

I am training on 2 GPUs, each of which uses a dataloader with 16 workers. Because of my training objective (contrastive loss), each GPU must receive different samples so that no batch contains duplicates.

I’m noticing that my dataloaders take up a very large amount of memory (~64 GB, at which point my machine crashes), and I don’t know why.

For context, my images are 240x240 and my batch size is 512. Thus, I expect my 2 dataloaders to take up 2 * 16 * 3 * 240 * 240 * 512 = 2,831,155,200 bytes ≈ 2.83 GB of memory at peak usage.
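To spell out the arithmetic, here is the estimate as a quick sketch (it assumes one byte per channel value and one in-flight batch per worker; float32 tensors would be 4x larger):

num_gpus = 2
num_workers = 16
batch_size = 512
channels, height, width = 3, 240, 240

bytes_per_image = channels * height * width  # 172,800 bytes at 1 byte per value
expected_peak = num_gpus * num_workers * batch_size * bytes_per_image
print(f"{expected_peak:,} bytes = {expected_peak / 1e9:.3f} GB")  # 2,831,155,200 bytes = 2.831 GB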

Here’s what the code looks like:

import pyarrow.parquet as pq
from torch.utils.data import DataLoader, get_worker_info

# friendly wrapper over torch.distributed
import composer.utils.dist as dist

def load_parquet(path):
    table = pq.read_table(path)
    imgs = table["img"]
    imgs = [load_tensor(img.as_py()) for img in imgs]  # load_tensor is a project helper that turns each stored value into a [1, 3, 240, 240] tensor
    del table

    shard_idx = dist.get_global_rank()
    worker_id = get_worker_info().id
    print(f"Loaded {len(imgs)} imgs (shape={imgs[0].shape}) from {path} with shard={shard_idx} & worker_id={worker_id}")
    return imgs

def sharded_parquet_dataloader(dir, **kwargs):
    print(f"Creating dl with kwargs: {kwargs}")
    datapipe = sharded_local_datapipe(dir)  # paths to this rank's parquet files
    datapipe = datapipe.sharding_filter()   # split the file list across DataLoader workers
    datapipe = datapipe.map(load_parquet)   # file path -> list of image tensors
    datapipe = datapipe.flatmap()           # list of tensors -> individual samples
    return DataLoader(datapipe, **kwargs)
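For completeness, I construct it roughly like this (matching the kwargs in the log below):

from pathlib import Path

dl = sharded_parquet_dataloader(
    Path("/data/train"),
    batch_size=512,
    num_workers=16,
    pin_memory=True,
    drop_last=True,
)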

The sharded_local_datapipe helper ensures that each GPU gets unique samples:

from pathlib import Path

import torchdata.datapipes as dp

def sharded_local_datapipe(dir: Path):
    world_size = dist.get_world_size()
    shard_idx = dist.get_global_rank()
    total_files = count_files(dir)
    # All data loaders should have the same size
    shard_size = total_files // world_size

    # For this to work, we assume the Lister is deterministic
    pipe = dp.iter.FSSpecFileLister(str(dir))
    pipe = pipe.enumerate()

    def in_shard(idx_and_name):
        idx, _ = idx_and_name
        return (idx % world_size) == shard_idx

    def take_last(args):
        return args[-1]

    pipe = pipe.filter(in_shard)
    pipe = pipe.map(take_last)
    return pipe.header(shard_size)
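To make the intended behavior concrete, here is a tiny standalone sketch of the same round-robin assignment (hypothetical file names, no torchdata involved):

files = [f"{i}.parquet" for i in range(6)]  # hypothetical listing, assumed to be deterministic
world_size = 2

for rank in range(world_size):
    shard = [name for idx, name in enumerate(files) if idx % world_size == rank]
    print(f"rank {rank}: {shard}")
# rank 0: ['0.parquet', '2.parquet', '4.parquet']
# rank 1: ['1.parquet', '3.parquet', '5.parquet']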

Does anyone know what might be going on here? Print statements indicate that not that many images are being loaded, so I’m trying to figure out where the extra memory is coming from:

contrastive_train-contrastive_train-1  | Train dataset: /data/train
contrastive_train-contrastive_train-1  | Val dataset: None
contrastive_train-contrastive_train-1  | t_warmup: 100ba
contrastive_train-contrastive_train-1  | t_max: 3600ba
contrastive_train-contrastive_train-1  | Creating dl with kwargs: {'batch_size': 512, 'num_workers': 16, 'pin_memory': True, 'drop_last': True}
contrastive_train-contrastive_train-1  | Checking dataloader...
contrastive_train-contrastive_train-1  | Loaded 243 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/102.parquet with shard=0 & worker_id=4
contrastive_train-contrastive_train-1  | flat-mapping: 243, shard=0, worker_id=4
contrastive_train-contrastive_train-1  | Loaded 242 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/97.parquet with shard=0 & worker_id=3
contrastive_train-contrastive_train-1  | Loaded 228 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/143.parquet with shard=0 & worker_id=5
contrastive_train-contrastive_train-1  | Loaded 228 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/135.parquet with shard=0 & worker_id=12
contrastive_train-contrastive_train-1  | Loaded 228 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/124.parquet with shard=0 & worker_id=4
contrastive_train-contrastive_train-1  | flat-mapping: 228, shard=0, worker_id=4
contrastive_train-contrastive_train-1  | Loaded 249 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/123.parquet with shard=0 & worker_id=14
contrastive_train-contrastive_train-1  | Loaded 254 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/117.parquet with shard=0 & worker_id=15
contrastive_train-contrastive_train-1  | flat-mapping: 254, shard=0, worker_id=15
contrastive_train-contrastive_train-1  | Loaded 200 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/129.parquet with shard=0 & worker_id=3
contrastive_train-contrastive_train-1  | Loaded 214 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/110.parquet with shard=0 & worker_id=4
contrastive_train-contrastive_train-1  | Loaded 258 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/116.parquet with shard=0 & worker_id=3
contrastive_train-contrastive_train-1  | Loaded 1000 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/48.parquet with shard=0 & worker_id=10
contrastive_train-contrastive_train-1  | Loaded 1000 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/45.parquet with shard=0 & worker_id=0
^CGracefully stopping... (press Ctrl+C again to force)

What would the memory usage be without pin_memory?

I’ll take a look. AFAIK, the memory usage is roughly the same.
What would be a good way to debug something like “too much memory usage”?

Is there anything worth using here besides tracemalloc?

Also, is there a way to check how much memory is being pinned?

Without any dataloaders, the memory usage is 22 GB.
With a dataloader and pin_memory=True, the usage is 49 GB.
With pin_memory=False, the usage is 40.5 GB.

Even without pinning memory, the usage is still far greater than I expected.

cc: @nivek to see if you have any recommendations on tools to monitor memory usage.

I think magic-trace (https://github.com/janestreet/magic-trace), which collects and displays high-resolution traces of what a process is doing, is a good tool for profiling.

BTW, have you ever tried to use this DataPipe to load from parquet?

I believe it reads the parquet file in a streaming way, which should reduce the memory footprint in theory.
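The same idea in plain pyarrow, independent of torchdata, would look roughly like this (just a sketch; load_tensor is the helper from your code above):

import pyarrow.parquet as pq

def iter_images(path, batch_size=256):
    # Stream the file in record batches instead of pq.read_table(path),
    # so only ~batch_size rows are decoded and resident at a time.
    parquet_file = pq.ParquetFile(path)
    for record_batch in parquet_file.iter_batches(batch_size=batch_size, columns=["img"]):
        for img in record_batch.column("img"):
            yield load_tensor(img.as_py())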

scalene is another option. It can give you memory usage line by line.
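If you just want quick numbers for the main process plus the 16 worker processes rather than a full profile, a psutil sketch along these lines might also help:

import os
import psutil

def report_rss():
    # Resident set size of this process and all child processes (e.g. DataLoader workers), in GB.
    main = psutil.Process(os.getpid())
    procs = [main] + main.children(recursive=True)
    for p in procs:
        print(f"pid={p.pid}: {p.memory_info().rss / 1e9:.2f} GB")
    total = sum(p.memory_info().rss for p in procs)
    print(f"total: {total / 1e9:.2f} GB across {len(procs)} processes")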