Recovering from Dataloader Error / Exception with utils.data.Subset

I recently came up with an approach to recover from a DataLoader error using a Subset, as an alternative to the exception handler discussed in "add exception handler for DataLoader when reading a damaged image file" (Issue #1137, pytorch/pytorch on GitHub).

I'm curious about one aspect: the Subset doesn't expose the custom methods of the original dataset, so I have to keep a separate reference to the original dataset. Is there any way of getting around this? It feels a bit messy.
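
One possible way around this, as a minimal sketch: `Subset` keeps the wrapped dataset in its `dataset` attribute, so custom methods should still be reachable through `subset.dataset` without a second variable. `ToyInferenceDataset` and its contents are hypothetical stand-ins for the real inference dataset; only `save_prediction_batch` is a name taken from my code.

```python
from torch.utils.data import Dataset, Subset


class ToyInferenceDataset(Dataset):
    """Hypothetical stand-in for the real inference dataset."""

    def __init__(self, n):
        self.data = list(range(n))
        self.saved = []

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def save_prediction_batch(self, preds):
        # Custom method that Subset itself does not expose.
        self.saved.append(preds)


full_ds = ToyInferenceDataset(10)
subset = Subset(full_ds, list(range(4, 10)))

# Subset supports indexing and len(), but does not proxy custom methods.
has_method = hasattr(subset, "save_prediction_batch")

# The wrapped dataset is still reachable through the .dataset attribute.
subset.dataset.save_prediction_batch([1, 2, 3])
```

With this, the loop could call `inference_ds.dataset.save_prediction_batch(...)` on whichever Subset is current, assuming the Subset always wraps the original dataset directly rather than another Subset.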

Overview of approach:

    import logging

    from torch.utils.data import Dataset, DataLoader, Subset
    from tqdm import tqdm

    log = logging.getLogger(__name__)

    inference_ds_copy = inference_ds   # same object, not a copy; do I need to keep this reference?

    selected = list(range(len(inference_ds_copy)))
    inference_ds = Subset(inference_ds_copy, selected)  # type: ignore
    loader = DataLoader(
        inference_ds,
        batch_size=batch_size,
        num_workers=0,
    )

    total_batches = len(loader)
    loop = iter(tqdm(loader, total=total_batches))

    for batch_idx in range(total_batches):
        try:
            batch = next(loop)

            batch_pred = inference_step(
                ...
            )

            inference_ds_copy.save_prediction_batch(
                ...
            )

        except Exception as e:
            log.info(f"Iteration {batch_idx} | Exception: {e}")

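            # rebuild dataloader with the remaining samples, skipping the failed batch;
            # batch_idx counts batches, so convert it to a sample index
            # (assumes the default shuffle=False and drop_last=False)
            start = (batch_idx + 1) * batch_size
            selected = list(range(start, len(inference_ds_copy)))
            inference_ds = Subset(inference_ds_copy, selected)  # type: ignore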
            loader = DataLoader(
                inference_ds,
                batch_size=batch_size,
                num_workers=0,
            )

            # restore progress bar, starting the counter after the failed batch
            pbar = tqdm(loader, total=total_batches, initial=batch_idx + 1)
            loop = iter(pbar)
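
One detail worth checking when rebuilding the Subset: `batch_idx` counts batches, while the indices passed to `Subset` count samples. The sketch below is plain Python with no PyTorch dependency and shows the conversion, assuming sequential sampling (`shuffle=False`) and `drop_last=False`. The helper name `remaining_sample_indices` and the example sizes are made up for illustration.

```python
import math


def remaining_sample_indices(failed_batch_idx, batch_size, dataset_len):
    """Sample indices left after skipping every batch up to and including the failed one.

    Assumes sequential sampling (shuffle=False) and drop_last=False.
    """
    start = min((failed_batch_idx + 1) * batch_size, dataset_len)
    return list(range(start, dataset_len))


dataset_len, batch_size = 10, 4
total_batches = math.ceil(dataset_len / batch_size)  # batches of size 4, 4, 2

# Suppose batch 0 raised: resume from sample 4, which leaves two batches.
remaining = remaining_sample_indices(0, batch_size, dataset_len)
batches_left = math.ceil(len(remaining) / batch_size)
```

Skipping the whole failed batch keeps the number of remaining batches equal to `total_batches - batch_idx - 1`, so a loop over `range(total_batches)` stays in step with the rebuilt loader. The cost is that the good samples in the failed batch are dropped along with the bad one.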