Datapipe optimisation


I’m working with tar files streamed from an S3 bucket. I find that the current Datapipe is slow at fetching the data, and the delay grows as more URLs (data_dir) are added.

I’ve tried moving the sharding_filter before load_from_tar, but this results in undesired behaviour: when more GPUs are added, sharding fails and each GPU receives the full dataset instead of its share (batch/GPUs).

datapipe = (
	torchdata.datapipes.iter.IterableWrapper(data_dir)
	.open_files_by_fsspec(mode="rb")  # load_from_tar expects (path, stream) pairs
	.load_from_tar()
	.batch(2)
	.map(self.to_sampels)
	.batch(self.batch_size)
)

service = [DistributedReadingService(), MultiProcessingReadingService(num_workers=self.num_workers)]
reading_service = SequentialReadingService(*service)
dataloader = DataLoader2(datapipe, reading_service=reading_service)

Try this config, let’s see:

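datapipe = (
    torchdata.datapipes.iter.IterableWrapper(data_dir)
    .enumerate()  # Pair each element with its index: (idx, element)
    .sharding_filter()  # Keep only this rank's/worker's share of the elements
    # ... remaining steps of the pipeline
)

service = [DistributedReadingService(), MultiProcessingReadingService(num_workers=4)]  # worker count is an example value
reading_service = SequentialReadingService(*service)
dataloader = DataLoader2(datapipe, reading_service=reading_service)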

I tested your suggestion of adding .enumerate(), but it doubled the loading time.

The step that takes the longest is load_from_tar; I assume it runs on the main node and is not split across workers.

Try creating a custom datapipe that fetches tar files in parallel, extracts the content, and yields the extracted files.
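A minimal sketch of that suggestion, assuming each archive fits in memory. To keep it stdlib-only it is a plain Python iterable rather than a torchdata IterDataPipe subclass; `ParallelTarLoader` and `read_tar_members` are hypothetical names, and `fetch` is a hypothetical callable that returns an archive's raw bytes (for S3 it might wrap fsspec or boto3, not shown). Members are read into bytes immediately, so what it yields is picklable.

```python
import io
import tarfile
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Iterable, Iterator, List, Tuple


def read_tar_members(tar_bytes: bytes) -> List[Tuple[str, bytes]]:
    """Extract every regular file from an in-memory tar archive.

    Each member is read into bytes right away, so the results are plain,
    picklable objects rather than open tarfile handles.
    """
    members: List[Tuple[str, bytes]] = []
    with tarfile.open(fileobj=io.BytesIO(tar_bytes), mode="r:*") as tar:
        for info in tar.getmembers():
            if info.isfile():
                handle = tar.extractfile(info)
                members.append((info.name, handle.read()))
    return members


class ParallelTarLoader:
    """Fetch several tar archives concurrently and yield (name, bytes) pairs."""

    def __init__(self, urls: Iterable[str], fetch: Callable[[str], bytes], max_workers: int = 8):
        self.urls = list(urls)
        self.fetch = fetch  # hypothetical: maps a URL to the archive's raw bytes
        self.max_workers = max_workers

    def _load(self, url: str) -> List[Tuple[str, bytes]]:
        return read_tar_members(self.fetch(url))

    def __iter__(self) -> Iterator[Tuple[str, bytes]]:
        with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
            # Downloads and extraction run concurrently in threads;
            # map() still returns results in the order of self.urls.
            for members in pool.map(self._load, self.urls):
                yield from members
```

Because `pool.map` submits every URL up front, all archives may end up in memory at once; with many large tars you would want to bound how many are in flight. In a multi-GPU or multi-worker setup the URL list would still need to be sharded before it reaches a loader like this, otherwise every worker fetches every archive.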

I don’t think this will work. The Datapipe is loading the full dataset before sharding_filter is applied. Ideally, load_from_tar should be performed after sharding.

This is likely a bug or an order-of-operations issue.

I also tried this approach, but it is still slow.

Hi @knoriy

You almost certainly want to use .sharding_filter() immediately after your first shuffle(). Otherwise, you will be duplicating work across workers (such as open_files_by_fsspec, load_from_tar, batch(2)).
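To make the duplicated-work point concrete, here is a pure-Python simulation (not torchdata code). It assumes sharding_filter behaves like a round-robin filter over its input stream, keeping element i when i % num_instances == instance_id, and that with two GPUs and two workers each there are four consumers in total. `open_tar` is a hypothetical stand-in for open_files_by_fsspec + load_from_tar that only records which archives were opened.

```python
from typing import Iterable, Iterator, List, Tuple


def sharding_filter(items: Iterable, num_instances: int, instance_id: int) -> Iterator:
    """Round-robin shard: keep element i when i % num_instances == instance_id."""
    for i, item in enumerate(items):
        if i % num_instances == instance_id:
            yield item


def open_tar(url: str, log: List[str]) -> List[Tuple[str, str]]:
    """Hypothetical stand-in for open_files_by_fsspec + load_from_tar.

    Records every archive it 'opens' and returns two fake members.
    """
    log.append(url)
    return [(url, "a.wav"), (url, "a.json")]


def shard_early(urls: List[str], n: int, k: int, log: List[str]) -> Iterator:
    # Shard the URL list first: consumer k opens only its own archives.
    for url in sharding_filter(urls, n, k):
        yield from open_tar(url, log)


def shard_late(urls: List[str], n: int, k: int, log: List[str]) -> Iterator:
    # Shard after loading: every consumer opens every archive,
    # then discards the members that belong to other consumers.
    members = (m for url in urls for m in open_tar(url, log))
    yield from sharding_filter(members, n, k)


if __name__ == "__main__":
    urls = [f"{i}.tar" for i in range(8)]
    n = 4  # e.g. 2 GPUs x 2 workers per GPU (assumed)
    early_log: List[str] = []
    late_log: List[str] = []
    for k in range(n):
        list(shard_early(urls, n, k, early_log))
        list(shard_late(urls, n, k, late_log))
    print("archive opens, shard early:", len(early_log))  # 8
    print("archive opens, shard late:", len(late_log))    # 32
```

Both variants deliver the same 16 members in total; the difference is that sharding late makes every consumer open all 8 archives.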

but this results in undesired behaviour, such that when more GPUs are added, the sharding fails, and each GPU receives the full data instead of batch/GPUs

Could you elaborate on this? Perhaps with something reproducible?

For example, it would be useful to check torch.distributed.get_world_size() and torch.distributed.get_rank(), as those are used for sharding by DistributedReadingService within DataLoader2.
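A small sketch for logging what each process sees before building the DataLoader2. It calls torch.distributed.get_world_size() and get_rank() when a process group is initialized, and otherwise falls back to the WORLD_SIZE/RANK environment variables that torchrun sets. `describe_shard` is a hypothetical helper, and the `total_shards` figure assumes that DistributedReadingService and MultiProcessingReadingService together split the data world_size × num_workers ways.

```python
import os
from typing import Dict


def describe_shard(num_workers: int) -> Dict[str, int]:
    """Report the rank/world size a process will use for sharding (hypothetical helper)."""
    try:
        import torch.distributed as dist
    except ImportError:  # torch not installed: rely on environment variables only
        dist = None

    if dist is not None and dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
        rank = dist.get_rank()
    else:
        # Not (yet) inside an initialized process group: use the env vars
        # that torchrun sets, defaulting to a single process.
        world_size = int(os.environ.get("WORLD_SIZE", "1"))
        rank = int(os.environ.get("RANK", "0"))

    return {
        "world_size": world_size,
        "rank": rank,
        "num_workers": num_workers,
        # Assumption: DRS splits across ranks, MPRS splits each rank's share across workers.
        "total_shards": world_size * num_workers,
    }


if __name__ == "__main__":
    # Call this from every rank; with 2 GPUs you would expect ranks 0 and 1,
    # each reporting world_size == 2.
    print(describe_shard(num_workers=12))
```

If every rank reports world_size == 1 (or the same rank), each GPU would be treated as the only consumer, which would match the "each GPU receives the full data" symptom.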

Hi @nivek, thank you for your comment.

I agree. Unfortunately, this results in some unexpected behaviour:

  1. When .sharding_filter is placed above batch(2), each GPU receives the full dataset instead of batch/GPUs. I fixed this by changing the order of the reading services ([MPRS, DRS] → [DRS, MPRS]) and moving .sharding_filter below .batch(2).
  2. Using MPRS and then DRS resulted in a "can't pickle: ExFileObject" error.


I fixed this by reordering the reading services. What is the significance of the order in which reading services are passed to DataLoader2?
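Separately from the reading-service order, one common way to sidestep the "cannot pickle 'ExFileObject'" error, assuming it comes from open tar member streams being pickled between processes, is to read each member into bytes right after load_from_tar. The sketch assumes elements are (name, stream) pairs, as load_from_tar yields them; `read_member` is a hypothetical helper.

```python
import io
import pickle
import tarfile
from typing import BinaryIO, Tuple


def read_member(item: Tuple[str, BinaryIO]) -> Tuple[str, bytes]:
    """Turn a (name, stream) pair into a picklable (name, bytes) pair."""
    name, stream = item
    return name, stream.read()


if __name__ == "__main__":
    # Build a tiny in-memory tar so the example runs without S3.
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tar:
        data = b'{"label": 1}'
        info = tarfile.TarInfo("0.json")
        info.size = len(data)
        tar.addfile(info, io.BytesIO(data))
    buf.seek(0)

    with tarfile.open(fileobj=buf, mode="r") as tar:
        member = tar.getmembers()[0]
        raw = (member.name, tar.extractfile(member))
        try:
            pickle.dumps(raw)
        except (TypeError, pickle.PicklingError) as err:
            print("raw stream:", err)  # e.g. cannot pickle 'ExFileObject' object
        print(pickle.loads(pickle.dumps(read_member(raw))))  # ('0.json', b'{"label": 1}')
```

In the pipeline this would sit directly after the tar step, e.g. `.load_from_tar().map(read_member)`, so only plain bytes flow downstream.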

Here is a simple working example using pytorch==2.1.0 and pytorch_lightning==2.0, where the paths in self.urls ("path/to/*.tar") point to an S3 bucket holding tar files that contain *.wav files and their corresponding *.json labels.

import io
import json

import soundfile
import torch
import torch.nn as nn
import torchdata
from torchdata.dataloader2 import DataLoader2, DistributedReadingService, MultiProcessingReadingService, SequentialReadingService

import pytorch_lightning as pl

class PLModel(pl.LightningModule):
	def __init__(self):
		super().__init__()
		self.layer = nn.Linear(32, 2)

	def forward(self, x):
		return x

	def training_step(self, batch, batch_idx):
		out = self(batch)
		# No loss is returned, so Lightning skips the optimizer step;
		# this example only exercises the data pipeline.

	def configure_optimizers(self):
		return torch.optim.SGD(self.layer.parameters(), lr=0.1)

class PL_datamodule(pl.LightningDataModule):
	def __init__(self):
		super().__init__()
		self.urls = ["path/to/0.tar", "path/to/1.tar"]

		self.batch_size = 16
		self.num_workers = 12

	def get_pipeline(self, data_dir):
		datapipe = (
			torchdata.datapipes.iter.IterableWrapper(data_dir)
			.open_files_by_fsspec(mode="rb")  # yields (path, stream) pairs for load_from_tar
			.load_from_tar()
			.batch(2)  # group each .wav with its .json label
			.sharding_filter()  # placed below batch(2), as described above
			.map(self.to_sampels)
			.batch(self.batch_size)
			# .map(self.collate_fn)
		)
		return datapipe

	def setup(self, stage=None):
		self.train = self.get_pipeline(self.urls)

	def train_dataloader(self):
		# DistributedReadingService first, then MultiProcessingReadingService
		service = [DistributedReadingService(), MultiProcessingReadingService(num_workers=self.num_workers)]
		reading_service = SequentialReadingService(*service)
		return DataLoader2(self.train, reading_service=reading_service)

	def to_sampels(self, data):
		# data is one batch(2) group; assumes the .wav member precedes its .json in the tar
		a, t = data
		return soundfile.read(io.BytesIO(a[1].read())), json.loads(t[1].read().decode('utf-8'))

def main(devices=1, max_epochs=1, strategy='ddp'):
	model = PLModel()
	datamodule = PL_datamodule()
	trainer = pl.Trainer(devices=devices, max_epochs=max_epochs, strategy=strategy)
	trainer.fit(model, datamodule=datamodule)

if __name__ == '__main__':
	main(devices=2, max_epochs=1, strategy='ddp')

As for the slowdown in my original post, it was caused by the large default buffer_size of the shuffle() placed after the sharding_filter(). Setting it to a smaller number helped.
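A simplified buffer-based shuffle illustrates why a large buffer hurts start-up. This is a sketch of the general technique, not torchdata's implementation: nothing is emitted until the buffer is full, so each consumer has to pull roughly buffer_size upstream elements before its first sample. torchdata's shuffle() defaults to buffer_size=10000. `buffered_shuffle` and `pulls_before_first_output` are hypothetical names.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(source: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Approximate shuffle using a fixed-size buffer (simplified sketch)."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be >= 1")
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in source:
        if len(buffer) < buffer_size:
            # Still filling the buffer: nothing can be emitted yet.
            buffer.append(item)
            continue
        # Buffer is full: emit a random element and put the new item in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Source exhausted: flush whatever is left, in random order.
    rng.shuffle(buffer)
    yield from buffer


def pulls_before_first_output(buffer_size: int, source_len: int = 100_000) -> int:
    """Count upstream elements consumed before the first shuffled element appears."""
    pulled = 0

    def source() -> Iterator[int]:
        nonlocal pulled
        for i in range(source_len):
            pulled += 1
            yield i

    next(buffered_shuffle(source(), buffer_size))
    return pulled


if __name__ == "__main__":
    for size in (20, 1_000, 10_000):
        print(size, pulls_before_first_output(size))  # buffer_size + 1 each time
```

If each upstream element is a decoded audio sample from a tar, a 10000-element buffer means loading that many samples per consumer before the first batch, plus holding them in memory.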

I am doing almost exactly the same thing but am running into an error when loading from tar: “cannot pickle ‘ExFileObject’ object”. Did you run into this, knoriy?

Yes, I have. Can you share your pipeline here?

I am no longer running into that error, but I am running into another issue. Help here would be greatly appreciated.

The new problem is that the fetch time is sometimes horrendous. I think that at times one worker ends up doing all of the work. With a single worker (num_workers=1), the mean/median/max fetch time is 30.3 s / 31.0 s / 40 s. With num_workers=4, the mean is 4-5 s and the median ~0.001 s, but it spikes to a max of ~40-50 s.
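One way to collect fetch-time statistics like these is to time how long each next() call on the loader blocks. A sketch under that assumption; `time_fetches` is a hypothetical helper, and the `clock` parameter only exists so it can be tested with a fake clock.

```python
import statistics
import time
from typing import Callable, Dict, Iterable, List


def time_fetches(
    loader: Iterable,
    max_batches: int,
    clock: Callable[[], float] = time.perf_counter,
) -> Dict[str, float]:
    """Measure how long each next() call on `loader` blocks.

    Returns mean/median/max wait in seconds plus the number of batches timed.
    """
    waits: List[float] = []
    it = iter(loader)
    for _ in range(max_batches):
        start = clock()
        try:
            next(it)
        except StopIteration:
            break
        waits.append(clock() - start)
    if not waits:
        raise ValueError("loader produced no batches")
    return {
        "mean": statistics.mean(waits),
        "median": statistics.median(waits),
        "max": max(waits),
        "n": len(waits),
    }
```

For example, `time_fetches(get_chunked_dataloader(get_chunked_dataset("train"), num_workers=4), max_batches=200)` would report the same three numbers for the pipeline in the code below.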

I’m fairly sure the big issue is load_from_tar(). With just one worker, that step plus groupby (and the batch size at the end) is enough to spike the timing.

Do you have any ideas for fixing this? I tried adding a sharding_filter just before load_from_tar, but that made it worse (median=0.0018754005432128906, mean=5.149842675139264). I tried adding it just after the groupby instead, and then it slowed down precipitously.

import io
import json
import os

import torchaudio
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper

def get_chunked_dataloader(datapipe, num_workers):
  # The datapipe is the result of get_chunked_dataset below.
  mp = MultiProcessingReadingService(num_workers=num_workers)
  return DataLoader2(datapipe, reading_service=mp)

def get_chunked_dataset(split, batch_size=100):
  assert split in ['train', 'val', 'test']
  bucket = "s3://<BUCKET>/%s" % split
  dp = IterableWrapper([bucket]).list_files_by_fsspec()
  dp = dp.shuffle()
  dp = dp.open_files_by_fsspec(mode="rb")
  dp = dp.prefetch(5)
  dp = dp.load_from_tar() # "r"
  # Pair each .wav with its .json by file stem (see groupby() below).
  dp = dp.groupby(groupby, group_size=2, guaranteed_group_size=2)
  dp = dp.map(process)
  # ExpandDataPipe is an IterDataPipe that takes in a source data pipe and
  # makes ~3-4 examples from each small wav + label combo it receives.
  # It's reasonably fast.
  dp = ExpandDataPipe(dp)
  dp = dp.shuffle(buffer_size=20)
  dp = dp.batch(batch_size)
  # dp = ...
  return dp

def process(row):
    # row is a group of two (path, stream) pairs: one .json and one .wav, in either order.
    if row[0][0].endswith('json'):
        stream_json_wrapper = row[0][1]
        stream_wav_wrapper = row[1][1]
    else:
        stream_json_wrapper = row[1][1]
        stream_wav_wrapper = row[0][1]
    labels = stream_json_wrapper.read()
    labels = json.loads(labels.decode('utf-8'))
    wav = io.BytesIO(stream_wav_wrapper.read())
    wav, _ = torchaudio.load(wav)
    return wav, labels

def groupby(row):
  # Group key: the file name without directory or extension.
  return os.path.basename(row[0]).split(".")[0]

@knoriy, mind taking a look? It looks similar to this thread: Slow aws S3 data loading using TorchData open_files_by_fsspec

It looks like you’re not using .sharding_filter(). This could mean the loading is not being split across workers, with each worker running the whole pipeline over every tar file.


The docs recommend placing it as early in your pipeline as possible. In your case, that would be right after the first .shuffle().

Give that a try 🙂

I tried that and it didn’t work: GPU power usage stayed between 60% and 80% the whole time. Without the sharding_filter, it’s much higher and only drops into the 60s when I’m loading the val set. Any other ideas?

hi all. this thread is not very young, but… (hi @cinjon! aren’t you … the Cinjon with audio research, also KC’s student? 👋)

did @cinjon or @knoriy figure out how to configure a datapipe with fast S3 loading? i would appreciate it very, very much if you can share your working setup 🙏

hehe hi! nice to see you here.

i gave up on this approach and went with mosaic’s streamer. it’s quite good and did/does the job well → GitHub - mosaicml/streaming: A Data Streaming Library for Efficient Neural Network Training.


haha 😃
gotcha, i’ve been debating so much about Torchdata… thanks for another datapoint, @cinjon!