Why does my datapipe based dataloader take up so much memory?

I am training on 2 GPUs, each of which uses a dataloader with 16 workers. Due to the nature of my training objective (contrastive loss), each GPU must receive distinct samples, so that no batch contains duplicates.

I’m noticing that my dataloader takes up a huge amount of memory (~64 GB, at which point my machine crashes), and I don’t know why.

For context, my images are 240x240 and my batch size is 512. Thus, I expect my 2 dataloaders to take up at most 2 * 16 * 3 * 240 * 240 * 512 = 2,831,155,200 bytes ≈ 2.83 GB of memory at peak usage (assuming one byte per pixel; float32 tensors would quadruple this).
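For reference, here's that arithmetic spelled out (the one-byte-per-pixel assumption is mine; the logs below suggest the decoded tensors are float, which would be 4 bytes per element):

```python
# Rough peak-memory estimate for the decoded image tensors alone.
num_gpus = 2
num_workers = 16
channels, height, width = 3, 240, 240
batch_size = 512

pixels_per_batch = num_gpus * num_workers * channels * height * width * batch_size
print(f"uint8:   {pixels_per_batch / 1e9:.2f} GB")      # 1 byte per pixel
print(f"float32: {pixels_per_batch * 4 / 1e9:.2f} GB")  # 4 bytes per pixel
```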

Here’s what the code looks like:

from pathlib import Path

import pyarrow.parquet as pq
import torchdata.datapipes as dp
from torch.utils.data import DataLoader, get_worker_info

# friendly wrapper over torch.distributed
import composer.utils.dist as dist

def load_parquet(path):
    # Decode every image in one parquet file into a tensor
    # (load_tensor is a small helper defined elsewhere in my code).
    table = pq.read_table(path)
    imgs = table["img"]
    imgs = [load_tensor(img.as_py()) for img in imgs]
    del table

    shard_idx = dist.get_global_rank()
    worker_id = get_worker_info().id
    print(f"Loaded {len(imgs)} imgs (shape={imgs[0].shape}) from {path} with shard={shard_idx} & worker_id={worker_id}")
    return imgs

def sharded_parquet_dataloader(dir, **kwargs):
    print(f"Creating dl with kwargs: {kwargs}")
    datapipe = sharded_local_datapipe(dir)
    datapipe = datapipe.sharding_filter()  # split files across the 16 workers
    datapipe = datapipe.map(load_parquet)  # each worker decodes whole parquet files
    datapipe = datapipe.flatmap()          # flatten the per-file lists into a stream of images
    return DataLoader(datapipe, **kwargs)

where sharded_local_datapipe is a helper that ensures each GPU gets unique samples:

def sharded_local_datapipe(dir: Path):
    world_size = dist.get_world_size()
    shard_idx = dist.get_global_rank()
    total_files = count_files(dir)
    # All data loaders should have the same size
    shard_size = total_files // world_size

    # For this to work, we assume the Lister is deterministic
    pipe = dp.iter.FSSpecFileLister(str(dir))
    pipe = pipe.enumerate()

    def in_shard(idx_and_name):
        idx, _ = idx_and_name
        return (idx % world_size) == shard_idx

    def take_last(args):
        return args[-1]

    pipe = pipe.filter(in_shard)  # keep every world_size-th file for this rank
    pipe = pipe.map(take_last)    # drop the index, keep just the filename
    # Truncate so every rank yields the same number of files
    return pipe.header(shard_size)
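To sanity-check the sharding logic, here's a self-contained sketch of the same enumerate/filter/header scheme in plain Python (no torchdata, made-up file names), confirming that two ranks get disjoint, equally sized file sets:

```python
# Minimal reproduction of the round-robin sharding above, without torchdata.
files = [f"{i}.parquet" for i in range(10)]  # made-up file list
world_size = 2

def shard(files, world_size, shard_idx):
    shard_size = len(files) // world_size
    kept = [name for idx, name in enumerate(files) if idx % world_size == shard_idx]
    return kept[:shard_size]  # like header(): truncate so all ranks agree on length

shard0 = shard(files, world_size, 0)
shard1 = shard(files, world_size, 1)
print(shard0)  # ['0.parquet', '2.parquet', '4.parquet', '6.parquet', '8.parquet']
print(shard1)  # ['1.parquet', '3.parquet', '5.parquet', '7.parquet', '9.parquet']
```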

Does anyone know what might be going on here? The print statements indicate that not very many images are being loaded, so I’m trying to figure out where the extra memory is coming from:

contrastive_train-contrastive_train-1  | Train dataset: /data/train
contrastive_train-contrastive_train-1  | Val dataset: None
contrastive_train-contrastive_train-1  | t_warmup: 100ba
contrastive_train-contrastive_train-1  | t_max: 3600ba
contrastive_train-contrastive_train-1  | Creating dl with kwargs: {'batch_size': 512, 'num_workers': 16, 'pin_memory': True, 'drop_last': True}
contrastive_train-contrastive_train-1  | Checking dataloader...
contrastive_train-contrastive_train-1  | Loaded 243 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/102.parquet with shard=0 & worker_id=4
contrastive_train-contrastive_train-1  | flat-mapping: 243, shard=0, worker_id=4
contrastive_train-contrastive_train-1  | Loaded 242 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/97.parquet with shard=0 & worker_id=3
contrastive_train-contrastive_train-1  | Loaded 228 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/143.parquet with shard=0 & worker_id=5
contrastive_train-contrastive_train-1  | Loaded 228 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/135.parquet with shard=0 & worker_id=12
contrastive_train-contrastive_train-1  | Loaded 228 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/124.parquet with shard=0 & worker_id=4
contrastive_train-contrastive_train-1  | flat-mapping: 228, shard=0, worker_id=4
contrastive_train-contrastive_train-1  | Loaded 249 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/123.parquet with shard=0 & worker_id=14
contrastive_train-contrastive_train-1  | Loaded 254 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/117.parquet with shard=0 & worker_id=15
contrastive_train-contrastive_train-1  | flat-mapping: 254, shard=0, worker_id=15
contrastive_train-contrastive_train-1  | Loaded 200 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/129.parquet with shard=0 & worker_id=3
contrastive_train-contrastive_train-1  | Loaded 214 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/110.parquet with shard=0 & worker_id=4
contrastive_train-contrastive_train-1  | Loaded 258 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/116.parquet with shard=0 & worker_id=3
contrastive_train-contrastive_train-1  | Loaded 1000 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/48.parquet with shard=0 & worker_id=10
contrastive_train-contrastive_train-1  | Loaded 1000 imgs (shape=torch.Size([1, 3, 240, 240])) from /data/train/45.parquet with shard=0 & worker_id=0
^CGracefully stopping... (press Ctrl+C again to force)

I’m also wondering: what would the memory usage be without pin_memory?
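In case it helps anyone reproduce this, here's how I could instrument the run to measure peak resident memory (a sketch using the stdlib resource module, Unix-only; on Linux ru_maxrss is reported in kilobytes):

```python
import resource

def peak_rss_gb():
    """Peak resident set size of this process, in GB (Linux reports KB)."""
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1e6

# Hypothetical usage around the training loop:
# for batch in dl:
#     ...
#     print(f"peak RSS so far: {peak_rss_gb():.2f} GB")
print(f"peak RSS: {peak_rss_gb():.2f} GB")
```

Note this only covers the parent process; each of the 16 workers is a separate process, so their memory would have to be summed separately (e.g. with psutil over the child processes).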