We are using the torchdata library to load a large number of TFRecord files; the code looks like this:
import time

import fsspec
from torchdata.datapipes.iter import IterableWrapper, SampleMultiplexer
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService

datapipes = []
for path in paths:
    # wrap the path and open the file(s) via fsspec
    datapipe = IterableWrapper([path]).open_files_by_fsspec(mode='rb')
    # reset fsspec's cached event loop so it plays well with multiprocessing
    fsspec.asyn.iothread[0] = None
    fsspec.asyn.loop[0] = None
    datapipe = datapipe.decompress(file_type=compression_type)
    datapipe = datapipe.load_from_tfrecord()
    datapipe = datapipe.cycle(num_epochs)
    datapipes.append(datapipe)

pipes_to_weights_dict = dict(zip(datapipes, dir_weights))
datapipe = SampleMultiplexer(pipes_to_weights_dict)
datapipe = datapipe.map(_load_single_example)

rs = MultiProcessingReadingService(num_workers=num_parallel_calls,
                                   worker_prefetch_cnt=prefetch_size // num_parallel_calls,
                                   main_prefetch_cnt=prefetch_size)
data_loader = DataLoader2(datapipe, reading_service=rs)
for i, e in enumerate(data_loader):
    print(time.time())
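One isolation step that might help (sketched below with a stand-in generator, since the real pipeline needs our data): iterate the datapipe directly, without DataLoader2 and the reading service, and time how long the first element takes. If the first element comes back quickly here but slowly through DataLoader2, the stall is in the multiprocessing setup rather than in the pipeline itself.

```python
import time

def slow_source():
    """Stand-in for the real datapipe: the first element takes a while."""
    time.sleep(0.2)  # simulate slow file opening / decompression
    for i in range(5):
        yield i

def time_to_first(iterable):
    """Return (first_element, seconds until it was produced)."""
    start = time.monotonic()
    first = next(iter(iterable))
    return first, time.monotonic() - start

# with the real pipeline this would be: time_to_first(datapipe)
first, elapsed = time_to_first(slow_source())
print(f"first element {first!r} after {elapsed:.2f}s")
```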
We found that it takes 10+ minutes to process the first batch. After spending some time adding debug messages, I found that if I add a print statement both before and inside the `for i, e in enumerate(data_loader)` loop, the message before the loop keeps printing, but the one inside the loop never does.
I also tried a dataset with only a few TFRecords; in that case the first batch is processed quickly.
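To narrow down which stage stalls, I've been experimenting with a small timing wrapper that records how long each stage takes to yield its first element. The wrapper and stage names below are mine, not part of torchdata, and the sketch uses a plain iterable standing in for a datapipe (in the real graph one would need a proper IterDataPipe wrapper to keep the functional API working):

```python
import time

def timed_stage(iterable, name, log):
    """Wrap any iterable (e.g. a DataPipe) and record into `log` the
    wall-clock time until it yields its first element."""
    def gen():
        start = time.monotonic()
        first = True
        for item in iterable:
            if first:
                log[name] = time.monotonic() - start
                first = False
            yield item
    return gen()

# in the real pipeline each construction step would be wrapped, e.g.
#   datapipe = timed_stage(datapipe.load_from_tfrecord(), 'tfrecord', log)
log = {}
stage = timed_stage(iter(range(3)), 'load_from_tfrecord', log)
items = list(stage)
print(items, log)
```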
- Can I get some help narrowing down the root cause of the long processing time for the first batch?
- It seems torchdata is not being actively developed; can you share more about the roadmap for the PyTorch data loading stack? Should we keep using torchdata or switch to something else?
Thanks
Dengpan