Optimizing Data Loading for Large-Scale Distributed Training

I’m currently training a GPT model on The Pile dataset on a single node with 8 A100 GPUs. The dataset is quite large, weighing in at around 950 GB on disk. For distributed training I’m using torchrun for its ease of use and seamless integration. However, I’ve run into an issue: the data loading step appears to be triggered 8 times in parallel (once per process), which I suspect causes excessive disk-read overhead and, consequently, a very slow load.

When running the training on a single GPU, the dataset takes about 10 minutes to load. However, with all 8 GPUs in action, the loading time increases significantly, taking over an hour.

Here’s the command I’m using to initiate the training:

torchrun -m --nproc_per_node=8 gpt train
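
Each of the 8 processes launched this way runs the same entry point, so anything done during setup (including dataset construction) happens once per process. For reference, torchrun sets these environment variables, which show which process is doing the loading:

import os

rank = int(os.environ.get("RANK", 0))              # global rank, 0-7 here
local_rank = int(os.environ.get("LOCAL_RANK", 0))  # GPU index on this node
world_size = int(os.environ.get("WORLD_SIZE", 1))  # 8 with --nproc_per_node=8
print(f"rank {rank}/{world_size} (local rank {local_rank}) is building the dataset")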

And this is the general structure of the relevant training code:

# Assuming the Hugging Face Trainer API for the classes used below
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments
from omegaconf import DictConfig

# GPT and prepare_data_module are defined elsewhere in the project
def train(cfg: DictConfig):

    # Setup data and dataloader
    data_module = prepare_data_module(**cfg.data)

    train_dataloader = data_module.train_dataloader()
    val_dataloader = data_module.val_dataloader()

    # Extract tokenizer from datamodule
    tokenizer = data_module.tokenizer

    # Setup model and optimizer
    model = GPT(**cfg.model)

    # Setup data collator
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    train_args = TrainingArguments(**cfg.train_args)

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=train_args,
        data_collator=data_collator,
        train_dataset=train_dataloader,
        eval_dataset=val_dataloader,
    )

    trainer.train()

I am seeking insights from the community on potential optimizations to resolve this data loading bottleneck. Are there any recommendations or best practices for efficiently handling data loading and distribution when training large-scale models like GPT with PyTorch?

Looking forward to your thoughts and suggestions!

cc @VitalyFedyunin for DataLoader questions

MosaicML’s Streaming library tries to address this exact use case: Streaming
Since it’s a drop-in replacement for a PyTorch Dataset, it might be a good fit for your program.
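
If you go that route, the data first has to be converted into Streaming’s shard format. A rough, untested sketch of that conversion (iter_pile_documents is just a placeholder for however you iterate the raw Pile files):

from streaming import MDSWriter

columns = {"text": "str"}  # one raw-text field per sample

with MDSWriter(out="/data/pile_mds", columns=columns, compression="zstd") as writer:
    for doc in iter_pile_documents("/data/pile"):  # placeholder reader, not a real API
        writer.write({"text": doc})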

Hi @suraj.pt!

No way, you replied to my post! I learned about this stuff through your videos. It’s amazing to get advice straight from you, thanks!

I looked into Streaming, but I’m not sure if it’s useful in my case.

Its documentation states:

StreamingDataset helps to make training on large datasets from cloud storage as fast, cheap, and scalable as possible.

I’m reading the data locally, so, if I’m not mistaken, this library would not work in my case, right?

I just realized it has a local option. I’ll look into it further.
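
From the docs, something like the following looks possible for a purely local copy (untested on my end, and the paths are made up):

from streaming import StreamingDataset
from torch.utils.data import DataLoader

# remote=None keeps everything on local disk, no cloud storage involved
train_set = StreamingDataset(local="/data/pile_mds", remote=None, shuffle=True, batch_size=32)
train_loader = DataLoader(train_set, batch_size=32, num_workers=4)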

Thanks a lot @suraj.pt!

What does the dataset initialisation do? I assume you don’t preload the dataset, because that would mean loading 8 x num_workers x 950 GB into memory. If you only generate some metadata for the loader, you could cache that metadata in a temporary file and have the dataset load from the cache, along the lines of the sketch below.
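
Untested sketch of that idea, where build_index stands in for whatever expensive scan your dataset currently does at init time:

import os
import pickle

import torch.distributed as dist

def load_or_build_index(data_dir: str, cache_path: str):
    # Only local rank 0 pays the cost of scanning the raw files.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if local_rank == 0 and not os.path.exists(cache_path):
        index = build_index(data_dir)  # placeholder for your metadata scan
        with open(cache_path, "wb") as f:
            pickle.dump(index, f)

    # If the process group is initialized, the other ranks wait here
    # until rank 0 has finished writing the cache.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()

    with open(cache_path, "rb") as f:
        return pickle.load(f)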