I’m working with tar files streamed from an S3 bucket. I find that the current DataPipe is slow at fetching all the data, and the delay grows with the number of URLs in data_dir.
I’ve tried moving sharding_filter before load_from_tar, but this results in undesired behaviour: when more GPUs are added, sharding fails and each GPU receives the full dataset instead of only its shard.
import torchdata.datapipes.iter
from torchdata.dataloader2 import (
    DataLoader2,
    DistributedReadingService,
    MultiProcessingReadingService,
    SequentialReadingService,
)

# Built inside a data module class, hence the self.* attributes.
datapipe = (
    torchdata.datapipes.iter.IterableWrapper(data_dir)
    .shuffle()
    .open_files_by_fsspec(mode='rb')
    .load_from_tar()
    .enumerate()          # assign an index to each sample
    .sharding_filter()    # keep samples based on their index and the GPU rank
    .shuffle()
    .map(self.to_samples)
    .batch(self.batch_size)
    .map(self.collate_fn)
)

service = [
    DistributedReadingService(),
    MultiProcessingReadingService(num_workers=self.num_workers),
]
reading_service = SequentialReadingService(*service)
dataloader = DataLoader2(datapipe, reading_service=reading_service)
I don’t think this will work. The DataPipe loads the full dataset before sharding_filter is applied; ideally, load_from_tar should be performed after sharding.
This is likely a bug or an order-of-operations issue.
You almost certainly want to use .sharding_filter() immediately after your first shuffle(). Otherwise, you will be duplicating work across workers (such as open_files_by_fsspec, load_from_tar, batch(2)).
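Roughly, this is the kind of sketch I mean, keeping the helper names and imports from your snippet above (untested, placement only):

# Shard-level sharding: tar URLs are shuffled and sharded before anything is opened.
datapipe = (
    torchdata.datapipes.iter.IterableWrapper(data_dir)   # list of tar shard URLs
    .shuffle()                          # shuffle the shard URLs
    .sharding_filter()                  # each rank/worker keeps only its share of URLs
    .open_files_by_fsspec(mode='rb')    # only this rank's shards are opened
    .load_from_tar()                    # and untarred
    .shuffle()                          # sample-level shuffle within the local shards
    .map(self.to_samples)
    .batch(self.batch_size)
    .map(self.collate_fn)
)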
> but this results in undesired behaviour: when more GPUs are added, sharding fails and each GPU receives the full dataset instead of only its shard.
As for this, can you further elaborate? Perhaps provide something reproducible?
For example, it would be useful to check torch.distributed.get_world_size() and torch.distributed.get_rank() as those will be used for sharding by DistributedReadingService within DataLoader2.
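For instance, something along these lines inside the training script, assuming the default process group has already been initialized by your launcher:

import torch.distributed as dist

if dist.is_available() and dist.is_initialized():
    # These are the values DistributedReadingService reads when sharding the graph.
    print(f"rank={dist.get_rank()} world_size={dist.get_world_size()}")
else:
    print("torch.distributed process group is not initialized")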
I agree with your statement. Unfortunately, putting .sharding_filter() there results in unexpected behaviour, namely:
- When .sharding_filter() is above .batch(2), each GPU received the full dataset rather than its 1/num_GPUs share. I fixed this by changing the order of the reading services (from [MPRS, DRS] to [DRS, MPRS]) and moving .sharding_filter() below .batch(2); see the sketch after this list.
- Using MPRS and then DRS resulted in a "can't pickle ExFileObject" error.
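For clarity, here is a rough sketch of the arrangement that ended up working for me (reusing the names and imports from my first snippet; placement only, not the full example):

# Sharding moved below batch(): each rank keeps whole batches.
datapipe = (
    torchdata.datapipes.iter.IterableWrapper(data_dir)
    .shuffle()
    .open_files_by_fsspec(mode='rb')
    .load_from_tar()
    .shuffle()
    .map(self.to_samples)
    .batch(2)                  # small batch size used while reproducing
    .sharding_filter()         # after batch(), so whole batches are sharded per rank
    .map(self.collate_fn)
)

reading_service = SequentialReadingService(
    DistributedReadingService(),                                  # DRS first,
    MultiProcessingReadingService(num_workers=self.num_workers),  # then MPRS
)
dataloader = DataLoader2(datapipe, reading_service=reading_service)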
Regarding my fix of reordering the reading services: what is the significance of the order in which reading services are passed to DataLoader2?
Here is a simple working example, using pytorch==2.1.0 and pytorch_lightning==2.0, where "path/to/*.tar" (line 21) is a path to an S3 bucket holding tar files containing *.wav files and their corresponding *.json labels.
Furthermore, regarding what was causing the slowdown in my original post: it was the large default buffer_size of the shuffle that comes after the sharding_filter. Setting it to a smaller number helped.
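For example, only the post-sharding shuffle changes; shards_dp below is a placeholder for the chain up to and including .sharding_filter(), and 100 is just an illustrative value (the Shuffler default buffer_size is 10000):

datapipe = (
    shards_dp                      # placeholder: the datapipe up to .sharding_filter()
    .shuffle(buffer_size=100)      # explicit small buffer instead of the 10000 default,
                                   # so a worker no longer pulls thousands of tar members
                                   # from S3 before yielding its first sample
    .map(self.to_samples)
    .batch(self.batch_size)
    .map(self.collate_fn)
)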