TorchData: Sharding, Caching, and Prefetching

Right now, my DataPipe and DataLoader2 look like the following. They read data from R2 using an AWS-style bucket_path with an endpoint_url specified. This is surprisingly slow, and I think I must be doing something wrong.

Some questions:

  1. If I were to add an on-disk cache, where should it go in the pipeline? I tried putting an on_disk_cache before the open_files_by_fsspec and the matching end_caching after it (roughly as in the sketch below), but that stalled the pipeline completely.
  2. The prefetch(5) doesn’t actually seem to do anything. Am I using it wrong?
  3. I want this pipeline to continue indefinitely. How do I use repeat with it? Do I put it … afterwards, as in the sketch after the pipeline code below? That seems inefficient.
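For reference, this is roughly the cache placement I tried for question 1 before backing it out (a minimal sketch of the attempt, not the exact code; cache_dir and filepath_fn are my own names, and the imports are the same as in the pipeline below):

cache_dir = "/tmp/r2_cache"  # local scratch directory, my choice

def filepath_fn(url):
    # map each remote file to a local cache path
    return os.path.join(cache_dir, os.path.basename(url))

dp = IterableWrapper([bucket_path]).list_files_by_fsspec()
dp = dp.sharding_filter()
dp = dp.on_disk_cache(filepath_fn=filepath_fn)         # start caching before the open
dp = dp.open_files_by_fsspec(mode="rb")
dp = dp.end_caching(mode="wb", same_filepath_fn=True)  # write each stream to filepath_fn
# ...rest of the pipeline unchanged

And here is the current pipeline, without the cache: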
import io
import json
import os

import torch
import torchaudio
from torchdata.datapipes.iter import IterableWrapper
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService

dp = IterableWrapper([bucket_path]).list_files_by_fsspec()  # list objects in the bucket
dp = dp.sharding_filter()  # shard the file list across workers before any heavy I/O
dp = dp.open_files_by_fsspec(mode="rb")
dp = dp.prefetch(5)  # the prefetch from question 2
dp = dp.load_from_tar()
dp = dp.groupby(groupby, group_size=2, guaranteed_group_size=2)  # pair each .wav with its .json
dp = dp.map(process)
dp = ExpandDataIPUPipe(dp, predict_ms=predict_ms)  # custom pipe, definition omitted
dp = dp.shuffle(buffer_size=10)
dp = dp.batch(batch_size)
dp = dp.map(collate)
mp = MultiProcessingReadingService(num_workers=num_workers)  # typically 12
dl = DataLoader2(dp, reading_service=mp)
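And for question 3, the only placement I can think of is tacking something onto the very end, along these lines (a sketch of what I mean; I'm not even sure whether cycle or repeat is the right op here):

dp = dp.map(collate)
dp = dp.cycle()  # restart the pipeline once exhausted? cycle() with no count loops forever
dl = DataLoader2(dp, reading_service=mp)

for wavs, labels in dl:  # should now yield batches indefinitely
    ...

Re-listing and re-downloading the bucket on every pass seems wasteful, hence the question.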

def collate(batch):
    # batch is a list of (wav, label) pairs produced by process()
    wavs = torch.cat([k[0].unsqueeze(0) for k in batch], dim=0)
    labels = torch.tensor([k[1] for k in batch], dtype=torch.int64)
    if wavs.shape[1] == 2:  # stereo: average the two channels down to mono
        wavs = wavs.mean(1)
    return wavs, labels

def groupby(row):
    # key tar entries by filename stem so "x.wav" and "x.json" land in the same group
    return os.path.basename(row[0]).split(".")[0]

def process(row):
    # row is a group of two (filename, stream) tuples: one .json, one .wav
    if row[0][0].endswith('json'):
        stream_json_wrapper = row[0][1]
        stream_wav_wrapper = row[1][1]
    else:
        stream_json_wrapper = row[1][1]
        stream_wav_wrapper = row[0][1]
    labels = json.loads(stream_json_wrapper.read().decode('utf-8'))
    wav = io.BytesIO(stream_wav_wrapper.read())
    wav, _ = torchaudio.load(wav)  # returns (waveform, sample_rate); rate unused
    return wav, labels