I’m working with tar files streamed from an S3 bucket. I find that the current DataPipe is slow at fetching all the data, and the delay grows with the number of URLs in data_dir.
I’ve tried moving the sharding_filter before load_from_tar, but this results in undesired behaviour: when more GPUs are added, the sharding fails and each GPU receives the full dataset instead of only its shard (batch/GPUs).
from torchdata.dataloader2 import (
    DataLoader2,
    DistributedReadingService,
    MultiProcessingReadingService,
    SequentialReadingService,
)
from torchdata.datapipes.iter import IterableWrapper

datapipe = (
    IterableWrapper(data_dir)  # data_dir: list of S3 URLs pointing at tar shards
    .shuffle()
    .open_files_by_fsspec(mode='rb')
    .load_from_tar()
    .enumerate()          # Assign an index to each sample
    .sharding_filter()    # Filter samples based on their index and the GPU rank
    .shuffle()
    .map(self.to_samples)
    .batch(self.batch_size)
    .map(self.collate_fn)
)

service = [
    DistributedReadingService(),
    MultiProcessingReadingService(num_workers=self.num_workers),
]
reading_service = SequentialReadingService(*service)
dataloader = DataLoader2(datapipe, reading_service=reading_service)
I don’t think this will work as written. The DataPipe loads the full dataset before sharding_filter is applied; ideally load_from_tar should be performed after sharding.
This is likely either a bug or an order-of-operations issue.
You almost certainly want to use .sharding_filter() immediately after your first shuffle(). Otherwise, you will be duplicating work across workers (such as open_files_by_fsspec, load_from_tar, batch(2)).
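For example, a rough sketch based on your snippet (I’ve dropped .enumerate(), since sharding_filter doesn’t require an explicit index; treat the exact placement as illustrative rather than a drop-in replacement):

datapipe = (
    IterableWrapper(data_dir)
    .shuffle()            # shuffle the shard URLs
    .sharding_filter()    # shard by URL, before any file is opened or decoded
    .open_files_by_fsspec(mode='rb')
    .load_from_tar()
    .shuffle()            # sample-level shuffle within each worker's shards
    .map(self.to_samples)
    .batch(self.batch_size)
    .map(self.collate_fn)
)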
but this results in undesired behaviour: when more GPUs are added, the sharding fails and each GPU receives the full dataset instead of only its shard (batch/GPUs)
As for this, can you further elaborate? Perhaps provide something reproducible?
For example, it would be useful to check torch.distributed.get_world_size() and torch.distributed.get_rank() as those will be used for sharding by DistributedReadingService within DataLoader2.
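Something along these lines inside the training process would do as a quick sanity check:

import torch.distributed as dist

if dist.is_available() and dist.is_initialized():
    print(f"rank={dist.get_rank()}, world_size={dist.get_world_size()}")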
I agree with your statement. Unfortunately, this results in unexpected behaviour:
When .sharding_filter is above .batch(2), each GPU received the full dataset rather than its share (batch/GPUs). I fixed this by changing the order of the reading services ([MPRS, DRS] >> [DRS, MPRS]) and moving .sharding_filter below .batch(2).
Using MPRS and then DRS resulted in a “can’t pickle: ExFileObject” error.
Re:
I fixed this by reordering the reading services. What is the significance of the order in which reading services are passed to DataLoader2?
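Concretely, this is what I mean by the two orders (a sketch of my setup rather than a minimal repro):

# This order failed for me (full dataset on every GPU / "can't pickle: ExFileObject"):
# reading_service = SequentialReadingService(
#     MultiProcessingReadingService(num_workers=self.num_workers),
#     DistributedReadingService(),
# )

# This order works:
reading_service = SequentialReadingService(
    DistributedReadingService(),
    MultiProcessingReadingService(num_workers=self.num_workers),
)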
Here is a simple working example, using pytorch==2.1.0 and pytorch_lightning==2.0, where "path/to/*.tar" (line 21) is a path to an S3 bucket holding tar files, each containing *.wav files and their corresponding *.json labels.
Furthermore, regarding the slowdown in my original post: it was caused by the large default buffer_size of the shuffle that follows the sharding_filter. Setting this to a smaller number helped.
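i.e. roughly the following (the exact buffer_size here is illustrative, not the value I settled on):

.sharding_filter()
.shuffle(buffer_size=100)   # default buffer_size is 10000, far too large for streamed samples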
I am doing almost exactly the same thing, but I’m running into an error when loading from tar: “cannot pickle ‘ExFileObject’ object”. Did you run into this, knoriy?
I am no longer running into that error, but I am running into another issue. Help here would be greatly appreciated.
The new problem is that the fetch time is sometimes horrendous. I think what’s happening is that, at times, a single worker has to do all of the work. The reason I think that is that with a single worker (num_workers=1) the mean/median/max fetch times are 30.3s / 31.0s / 40s, while with num_workers=4 the mean is 4-5s, the median is ~0.001s, and it spikes to a max of ~40-50s.
I’m pretty sure that the big issue is load_from_tar(). When I use just one worker, that and the groupby (plus the batch size at the bottom) are enough to spike the timing.
Do you have any ideas for fixing this? I tried adding a sharding_filter just before the load_from_tar, but that made it worse → median=0.0018754005432128906, mean=5.149842675139264. I tried adding it just after the groupby and then it slowed down precipitously.
import io
import json
import os

import torchaudio
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper


def get_chunked_dataloader(datapipe, num_workers):
    # The datapipe is the result of get_chunked_dataset below.
    mp = MultiProcessingReadingService(num_workers=num_workers)
    return DataLoader2(datapipe, reading_service=mp)


def get_chunked_dataset(split, batch_size=100):
    assert split in ['train', 'val', 'test']
    bucket = "s3://<BUCKET>/%s" % split
    dp = IterableWrapper([bucket]).list_files_by_fsspec()
    dp = dp.shuffle()
    dp = dp.open_files_by_fsspec(mode="rb")
    dp = dp.prefetch(5)
    dp = dp.load_from_tar()  # "r"
    dp = dp.groupby(groupby, group_size=2, guaranteed_group_size=2)
    dp = dp.map(process)
    # ExpandDataPipe is an IterDataPipe that takes in a source data pipe and
    # makes ~3-4 examples from each small wav + label combo it receives.
    # It's reasonably fast.
    dp = ExpandDataPipe(dp)
    dp = dp.shuffle(buffer_size=20)
    dp = dp.batch(batch_size)
    dp = dp.map(collate)  # collate is defined elsewhere
    return dp


def process(row):
    # Each group is a pair of (name, stream) tar members; the order of the
    # .json and .wav entries within the pair isn't guaranteed.
    if row[0][0].endswith('json'):
        stream_json_wrapper = row[0][1]
        stream_wav_wrapper = row[1][1]
    else:
        stream_json_wrapper = row[1][1]
        stream_wav_wrapper = row[0][1]
    labels = stream_json_wrapper.read()
    labels = json.loads(labels.decode('utf-8'))
    wav = io.BytesIO(stream_wav_wrapper.read())
    wav, _ = torchaudio.load(wav)
    return wav, labels


def groupby(row):
    # Group tar members by basename (without extension), so each .wav is
    # paired with its .json label.
    return os.path.basename(row[0]).split(".")[0]
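For completeness, this is roughly how I wire it up (num_workers=4 is what I’ve been testing with; the batch contents are whatever collate returns):

train_dp = get_chunked_dataset('train')
train_dl = get_chunked_dataloader(train_dp, num_workers=4)

for batch in train_dl:
    ...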
I tried that and it didn’t work. GPU power usage was between 60% and 80% the whole time. Without the sharding_filter, it’s much higher and bounces down to the 60s only when I’m loading in the val set. Any other ideas?