How to avoid pickling errors when using PyTorch DDP with S3 data?

Hi,
I’m using PyTorch DDP to data-parallelize this code: GitHub - aws-samples/sagemaker-a2d2-segmentation-pytorch

I’m launching the training with

import torch
import torch.multiprocessing as mp


def main():

    # launches 1 data-parallel replica per GPU;
    # train_replica(rank) is defined elsewhere in the training script
    world_size = torch.cuda.device_count()
    mp.spawn(
        fn=train_replica,
        nprocs=world_size,
        join=True)


if __name__ == "__main__":

    main()

My dataset class uses boto3 to move records from S3 into memory.
To avoid boto3 serialization issues, the dataset and dataloaders are instantiated inside the train_replica function, so that each spawned process creates its own copies and no boto3 objects should need to be serialized.
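Roughly, the per-replica setup looks like this (a simplified sketch; S3Dataset, the bucket name, and the keys are placeholders, not the exact repo code):

import boto3
from torch.utils.data import Dataset, DataLoader

class S3Dataset(Dataset):
    def __init__(self, bucket, keys):
        self.bucket = bucket
        self.keys = keys
        # boto3 resource held as an attribute of the dataset
        self.s3 = boto3.resource("s3")

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        obj = self.s3.Object(self.bucket, self.keys[idx])
        return obj.get()["Body"].read()

def train_replica(rank):
    # dataset and dataloader live entirely inside the spawned process
    dataset = S3Dataset("my-bucket", ["records/0001.bin", "records/0002.bin"])
    loader = DataLoader(dataset, batch_size=8, num_workers=4)
    ...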

However, my training fails with:


_pickle.PicklingError: Can't pickle <class 'boto3.resources.factory.s3.ServiceResource'>: attribute lookup s3.ServiceResource on boto3.resources.factory failed

How can I avoid that? Given how popular S3 is, I assume it must be possible to use DDP with a dataset whose __getitem__ reads from S3, right? How do I avoid the serialization problem?

It looks like some multiprocessing code is attempting to pickle data to send across processes, but boto3 resources can’t be pickled.
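The failure is easy to reproduce in isolation: the ServiceResource class is generated at runtime by boto3’s resource factory, so pickle cannot look it up by module path, which is exactly what the attribute lookup error above is saying. For example:

import pickle
import boto3

s3 = boto3.resource("s3", region_name="us-east-1")

try:
    pickle.dumps(s3)
except pickle.PicklingError as e:
    # Can't pickle <class 'boto3.resources.factory.s3.ServiceResource'>: ...
    print(e)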

Similar to this comment: PicklingError encountered when using multiple GPUs · Issue #67681 · pytorch/pytorch · GitHub, you might want to try:

  1. Pass plain (picklable) arguments across processes and construct the boto3 client inside each subprocess, e.g. lazily on first use, as in the sketch below.
  2. Cache or store the results of boto3 API calls in a format that can be pickled, and pass those results around instead.
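
For option 1, a common pattern is to keep only picklable configuration on the dataset and create the boto3 resource lazily, so each process builds its own handle after it has been spawned. A rough sketch (LazyS3Dataset and its names are illustrative, not the repo’s actual code):

import boto3
from torch.utils.data import Dataset

class LazyS3Dataset(Dataset):
    def __init__(self, bucket, keys):
        # only picklable state is stored at construction time
        self.bucket = bucket
        self.keys = keys
        self._s3 = None  # created lazily, once per process

    @property
    def s3(self):
        if self._s3 is None:
            # first access inside the (sub)process builds the resource
            self._s3 = boto3.resource("s3")
        return self._s3

    def __getstate__(self):
        # if the dataset itself gets pickled (e.g. for DataLoader workers),
        # drop the live resource so only picklable state crosses over
        state = self.__dict__.copy()
        state["_s3"] = None
        return state

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        obj = self.s3.Object(self.bucket, self.keys[idx])
        return obj.get()["Body"].read()

An equivalent alternative is to (re)build the client in a worker_init_fn passed to the DataLoader; either way, the key point is that no live boto3 handle is ever part of what gets pickled.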