Multiprocessing Data Loading producing errors

Hi,

My data loader on AWS SageMaker runs fine with num_workers = 0, but produces the error below when num_workers > 0.

Traceback (most recent call last):
  File "main.py", line 371, in <module>
    g_scaler=g_scaler, d_scaler=d_scaler, runtime_log_folder=runtime_log_folder, runtime_log_file_name=runtime_log_file_name)
  File "main.py", line 78, in train_fn
    for idx, (x, y) in enumerate(loop):
  File "/opt/conda/lib/python3.6/site-packages/tqdm/std.py", line 1171, in __iter__
    for obj in iterable:
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 525, in __next__
    (data, worker_id) = self._next_data()
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1252, in _next_data
    return (self._process_data(data), w_id)
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1299, in _process_data
    data.reraise()
  File "/opt/conda/lib/python3.6/site-packages/torch/_utils.py", line 429, in reraise
    raise self.exc_type(msg)
  File "/opt/conda/lib/python3.6/site-packages/botocore/exceptions.py", line 84, in __init__
    super(HTTPClientError, self).__init__(**kwargs)
  File "/opt/conda/lib/python3.6/site-packages/botocore/exceptions.py", line 40, in __init__
    msg = self.fmt.format(**kwargs)
KeyError: 'error'

---------------------------------------------------------------------------
UnexpectedStatusException                 Traceback (most recent call last)
<ipython-input-1-81655136a841> in <module>
     58                             py_version='py3')
     59 
---> 60 pytorch_estimator.fit({'train': Runtime.dataset_path}, job_name=Runtime.job_name)
     61 
     62 #print(pytorch_estimator.latest_job_tensorboard_artifacts_path())

~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/sagemaker/estimator.py in fit(self, inputs, wait, logs, job_name, experiment_config)
    955         self.jobs.append(self.latest_training_job)
    956         if wait:
--> 957             self.latest_training_job.wait(logs=logs)
    958 
    959     def _compilation_job_name(self):

~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/sagemaker/estimator.py in wait(self, logs)
   1954         # If logs are requested, call logs_for_jobs.
   1955         if logs != "None":
-> 1956             self.sagemaker_session.logs_for_job(self.job_name, wait=True, log_type=logs)
   1957         else:
   1958             self.sagemaker_session.wait_for_job(self.job_name)

~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/sagemaker/session.py in logs_for_job(self, job_name, wait, poll, log_type)
   3751 
   3752         if wait:
-> 3753             self._check_job_status(job_name, description, "TrainingJobStatus")
   3754             if dot:
   3755                 print()

~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/sagemaker/session.py in _check_job_status(self, job, desc, status_key_name)
   3304                 ),
   3305                 allowed_statuses=["Completed", "Stopped"],
-> 3306                 actual_status=status,
   3307             )
   3308 

UnexpectedStatusException: Error for Training job 2022-06-03-05-16-49-pix2pix-U12239-2022-05-09-14-39-18-training: Failed. Reason: AlgorithmError: ExecuteUserScriptError:
Command "/opt/conda/bin/python3.6 main.py --runtime_var dataset_name=U12239-2022-05-09-14-39-18,job_name=2022-06-03-05-16-49-pix2pix-U12239-2022-05-09-14-39-18-training,model_name=pix2pix"

  0%|          | 0/248 [00:00<?, ?it/s]
  0%|          | 1/248 [00:30<2:07:28, 30.97s/it]
  0%|          | 1/248 [00:30<2:07:28, 30.97s/it]
Traceback (most recent call last):
  File "main.py", line 371, in <module>
    g_scaler=g_scaler, d_scaler=d_scaler, runtime_log_folder=runtime_log_folder, runtime_log_file_name=runtime_log_file_name)
  File "main.py", line 78, in train_fn
    for idx, (x, y) in enumerate(loop):
  File "/opt/conda/lib/python3.6/site-packages/tqdm/std.py", line 1171, in __iter__
    for obj in iterable:
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 525, in __next__
    (data, worker_id) = self._next_data()
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1252, in _next_data
    return (self

I am using a map-style dataset that reads from AWS S3. The data loader fails with multiprocessing data loading but works fine with single-process data loading.

Does this mean that I should change to an iterable-style dataset in order to make multiprocessing data loading work? Or is there a way to make multiprocessing data loading work with a map-style dataset like mine?

Have you tried a newer version of PyTorch? Also, I can't find a useful log in your traceback. It would be helpful if you could provide some code snippets.

It seems the error comes from boto. You might want to open an issue with AWS S3 to get help.
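
A note on why the traceback is not very informative: the final KeyError: 'error' is raised inside the exception's own __init__ while the DataLoader re-raises a worker exception in the main process, so the worker's original message is lost. One possible way to surface it is to convert any failure in __getitem__ into a RuntimeError that carries the original traceback as text. The sketch below is only an illustration under assumptions: TemplateError, report_getitem_errors, and FlakyDataset are hypothetical names, and TemplateError is a stand-in for an exception that needs keyword arguments, not the real botocore class.

```python
import functools
import traceback


class TemplateError(Exception):
    """Hypothetical stand-in for an exception that builds its message from
    keyword arguments and therefore cannot be rebuilt from a single string."""

    fmt = "An HTTP client raised an unhandled exception: {error}"

    def __init__(self, request=None, **kwargs):
        # Raises KeyError('error') if the 'error' keyword is missing.
        super().__init__(self.fmt.format(**kwargs))


def report_getitem_errors(getitem):
    """Re-raise any failure in __getitem__ as a RuntimeError whose message
    contains the full original traceback as text."""

    @functools.wraps(getitem)
    def wrapper(self, index):
        try:
            return getitem(self, index)
        except Exception:
            details = traceback.format_exc()
            raise RuntimeError(f"__getitem__({index}) failed:\n{details}") from None

    return wrapper


class FlakyDataset:
    """Toy map-style dataset whose item access always fails."""

    def __len__(self):
        return 4

    @report_getitem_errors
    def __getitem__(self, index):
        raise TemplateError(error="connection reset")
```

Because RuntimeError can be constructed from a single message string, it should survive being re-raised in the main process, and the message then shows which index failed and why.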

Here is the code that produces the error:

def train_fn(disc, gen, loader, opt_disc, opt_gen, l1, bce, g_scaler, d_scaler,runtime_log_folder,runtime_log_file_name):

    total_output=''
    
    loop = tqdm(loader, leave=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print("Loop")
    print(loop)
    print("Length loop")
    print(len(loop))
    for idx, (x, y) in enumerate(loop): #<--error happens here
        print("Loop index")
        print(idx)
        print("Loop item")
        print(x,y)
        x = x.to(device)
        y = y.to(device)
        
        # train discriminator
        with torch.cuda.amp.autocast():
            y_fake = gen(x)

            D_real = disc(x, y)
            D_fake = disc(x, y_fake.detach())
            # detach y_fake so that the discriminator step does not backpropagate into the generator's graph
            # alternatively, skip detach and call loss.backward(retain_graph=True)

            D_real_loss = bce(D_real, torch.ones_like(D_real))
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))  # fake samples get target 0 for the discriminator

            D_loss = (D_real_loss + D_fake_loss) / 2
            
            # log tensorboard
            
        disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()
        
        # train generator
        with torch.cuda.amp.autocast():
            
            D_fake = disc(x, y_fake)

            # compute fake loss
            # trick the discriminator into believing these are real, hence the target torch.ones_like(D_fake)
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))

            # compute L1 loss
            L1 = l1(y_fake, y) * args.l1_lambda

            G_loss = G_fake_loss + L1
            
            # log tensorboard
           
        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()
        
        # print epoch, generator loss, discriminator loss
        print(f'[Epoch {epoch}/{args.num_epochs} (b: {idx})] [D loss: {D_loss}, D real loss: {D_real_loss}, D fake loss: {D_fake_loss}] [G loss: ##{G_loss}, G fake loss: {G_fake_loss}, L1 loss: {L1}]')
        output = f'[Epoch {epoch}/{args.num_epochs} (b: {idx})] [D loss: {D_loss}, D real loss: {D_real_loss}, D fake loss: {D_fake_loss}] [G loss: ##{G_loss}, G fake loss: {G_fake_loss}, L1 loss: {L1}]\n'
        total_output+=output



    runtime_log = get_json_file_from_s3(runtime_log_folder, runtime_log_file_name)
    runtime_log += total_output
    upload_json_file_to_s3(runtime_log_folder,runtime_log_file_name,json.dumps(runtime_log))



def __getitem__(self, index):
    print("Index ",index)
    pair_key = self.list_files[index]
    print("Pair key ",pair_key)
    pair = Boto.s3_client.list_objects(Bucket=Boto.bucket_name, Prefix=pair_key, Delimiter='/')

    input_image_key = pair.get('Contents')[1].get('Key')
    input_image_path = f's3://{Boto.bucket_name}/{input_image_key}'
    print("Input image path ",input_image_path)
    input_image_s3_source = get_file_from_filepath(input_image_path)
    input_image = np.array(Image.open(input_image_s3_source))

    target_image_key = pair.get('Contents')[0].get('Key')
    target_image_path = f's3://{Boto.bucket_name}/{target_image_key}'
    print("Target image path ",target_image_path)
    target_image_s3_source = get_file_from_filepath(target_image_path)
    target_image = np.array(Image.open(target_image_s3_source))

    augmentations = config.both_transform(image=input_image, image0=target_image)

    # get input image and target image by doing augmentations of images
    input_image, target_image = augmentations['image'], augmentations['image0']

    input_image = config.transform_only_input(image=input_image)['image']
    target_image = config.transform_only_mask(image=target_image)['image']
    
    print("Input image size ",input_image.size())
    print("Target image size ",target_image.size())
    
    return input_image, target_image

Here is a summary of the failure points from the logs:

i) 2022-06-03-05-00-04-pix2pix-U12239-2022-05-09-14-39-18-training
No index shown
[Epoch 0/100 (b: 0)]

ii) 2022-06-03-05-16-49-pix2pix-U12239-2022-05-09-14-39-18-training
Index  160
[Epoch 0/100 (b: 0)]

iii) 2022-06-03-05-44-46-pix2pix-U12239-2022-05-09-14-39-18-training
Index  160
[Epoch 0/100 (b: 0)]

iv) 2022-06-03-06-08-33-pix2pix-U12239-2022-05-09-14-39-18-training
Index  160
[Epoch 1/100 (b: 0)]

v) 2022-06-15-02-49-20-pix2pix-U12239-2022-05-09-14-39-18-training
Index  160
Pair key  datasets/training-data/testing/2022-05-09-14-39-18/match-raws-finals/U12239/P423712/Pair_71/
[Epoch 0/100 (b: 0)]

vi) 2022-06-15-02-59-43-pix2pix-U12239-2022-05-09-14-39-18-training
Index  64
Pair key  datasets/training-data/testing/2022-05-09-14-39-18/match-raws-finals/U12239/P425642/Pair_27/
[Epoch 0/100 (b: 247)]

vii) 2022-06-15-04-49-33-pix2pix-U12239-2022-05-09-14-39-18-training
Index  64
Pair key  datasets/training-data/testing/2022-05-09-14-39-18/match-raws-finals/U12239/P415414/Pair_124/
No specific epoch 

Essentially, it seems to fail either at the start of an epoch (batch 0) or at the end of an epoch (batch 247), and the failure appears to be caused by multiprocessing.

I've tried updating the PyTorch version, but the error still happens. I also raised this with AWS support, but their suggestion was to use single-process loading, which isn't really workable for me: it took 1.5 hours/epoch with single-process loading (num_workers = 0), compared to 40 minutes/epoch with multiprocessing (num_workers = 2).

I have also attached the code and logs summary in the earlier reply.

The boto client is not picklable, so you can't create it before multiprocessing starts.

My recommendation would be to construct the boto client lazily in your Dataset:

  1. Put None as the client in your dataset's __init__ function as a placeholder.
  2. In the __getitem__ function, create the client object if the client is None.

Then, you should be able to run your code with multiprocessing.
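
The two steps above might look roughly like the sketch below. It is a minimal illustration under assumptions, not the original poster's code: LazyClientDataset, make_s3_client, and the client_factory parameter are hypothetical names, and the class is written as a plain map-style object (it defines __len__ and __getitem__), whereas real code would typically subclass torch.utils.data.Dataset and go on to download and transform the images.

```python
def make_s3_client():
    """Hypothetical factory: imports boto3 only when a client is needed."""
    import boto3

    return boto3.client("s3")


class LazyClientDataset:
    """Map-style dataset that creates its S3 client on first item access."""

    def __init__(self, bucket_name, list_files, client_factory=make_s3_client):
        self.bucket_name = bucket_name
        self.list_files = list_files
        self._client_factory = client_factory
        # Step 1: placeholder only, so no client object exists yet
        # when the DataLoader starts its worker processes.
        self._client = None

    def _get_client(self):
        # Step 2: create the client the first time it is needed.
        if self._client is None:
            self._client = self._client_factory()
        return self._client

    def __len__(self):
        return len(self.list_files)

    def __getitem__(self, index):
        pair_key = self.list_files[index]
        client = self._get_client()
        pair = client.list_objects(
            Bucket=self.bucket_name, Prefix=pair_key, Delimiter="/"
        )
        # Return only the object keys here; loading and augmenting the
        # images would follow as in the original __getitem__.
        return [obj["Key"] for obj in pair.get("Contents", [])]
```

With this layout, each worker should end up building its own client on its first __getitem__ call, provided the dataset is not indexed in the main process before the workers start. The module-level Boto.s3_client used in the code posted earlier would be replaced by self._get_client().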

I did that and it worked, thanks.