Iterable Dataset Reading from Disk - Hangs DDP

I am trying to train a model (on AWS SageMaker) using DDP. I am reading from multiple parquet files with an IterableDataset and attempting to split the reads between DataLoader workers and ranks.

import numpy as np
import pyarrow.parquet as pq
from glob import glob
from torch.utils.data import IterableDataset, get_worker_info

class iterDataset_train(IterableDataset):
    def __init__(self, dirpath, args, force_non_dist=False):
        self.dirpath = dirpath
        self.dim_reshape_time = args.dim_reshape_time
        self.dim_reshape_features = args.dim_reshape_features
        self.world_size = args.world_size
        self.rank = args.rank

        if force_non_dist:
            self.is_distributed = False
        else:
            self.is_distributed = args.is_distributed

    def iter_parquet(self):
        # get_worker_info() returns None when num_workers=0 (loading in the main process)
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0

        if self.is_distributed:
            world_size = self.world_size
            rank = self.rank
        else:
            world_size = 1
            rank = 0

        # every (rank, worker) pair gets a unique shard index in [0, num_workers * world_size)
        worker_total_num = num_workers * world_size
        global_worker_id = worker_id * world_size + rank

        # glob() order is arbitrary; sort so every rank splits the same list
        lst_glob = sorted(glob(f'{self.dirpath}/*.parquet'))
        lst_splits = np.array_split(lst_glob, worker_total_num)

        for fpath in lst_splits[global_worker_id].tolist():
            tbl = pq.ParquetFile(fpath)

            for group_i in range(tbl.num_row_groups):  # for each group of rows
                row_group = tbl.read_row_group(group_i)  # read group

                for batch in row_group.to_batches():  # loop through the record batches of the group
                    for row in zip(*batch.columns):  # loop through the rows of the batch, combining the columns
                        features = row[1].values.to_numpy(zero_copy_only=False, writable=True).reshape((self.dim_reshape_time, self.dim_reshape_features)).astype(np.float64)
                        targets = row[2].values.to_numpy(zero_copy_only=False, writable=True).astype(np.float64)

                        yield features, targets

    def __iter__(self):
        return self.iter_parquet()
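To see how the split above distributes files, the mapping from (rank, worker) to files can be pulled out into a small helper. This is a sketch that mirrors `iter_parquet`; `shard_files` is a hypothetical name, and it sorts the file list because `glob()` returns files in arbitrary order, which could otherwise differ between ranks or nodes.

```python
import numpy as np

def shard_files(files, num_workers, world_size):
    """Return {(rank, worker_id): [files]} using the same split as iter_parquet."""
    worker_total_num = num_workers * world_size
    # glob() order is arbitrary, so sort to give every rank the identical list
    splits = np.array_split(sorted(files), worker_total_num)
    assignment = {}
    for rank in range(world_size):
        for worker_id in range(num_workers):
            global_worker_id = worker_id * world_size + rank
            assignment[(rank, worker_id)] = splits[global_worker_id].tolist()
    return assignment

files = [f"part-{i}.parquet" for i in range(9)]
for (rank, worker_id), shard in shard_files(files, num_workers=2, world_size=2).items():
    print(f"rank {rank}, worker {worker_id}: {shard}")
```

With 9 files, 2 workers and 2 ranks, rank 0 reads 5 files and rank 1 reads 4, so the ranks can end up with different numbers of rows (and batches) even before differences in rows per file are considered.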

The data loader looks like this:

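    from torch.utils.data import DataLoader

    train_ds = iterDataset_train(args.train_dir, args)
    train_loader = DataLoader(
        train_ds,
        shuffle=False,  # must be False for an IterableDataset; the dataset controls ordering
        batch_size=args.batch_size_train,
        drop_last=True,
        num_workers=args.num_workers,
    )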

I am seeing the script hang at various places. Before delving into debugging, I am wondering whether the way I am iterating through the data is the issue. I have seen several posts suggesting that each rank (a GPU? or a node, with multi-node training?) needs to get the same number of batches. Is this indeed the case, and could it be what is causing my hangs? Is there a way to guarantee this with the setup above?
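DDP all-reduces gradients during `backward()`, and every rank has to take part in each all-reduce, so a rank that runs out of batches first stops participating while the others wait. A rough way to check the setup above is to compute how many full batches each rank would yield. This sketch assumes the per-file row counts are known, listed in sorted file order, and that with an `IterableDataset` each DataLoader worker builds its own batches, so `drop_last=True` drops one partial batch per worker. `batches_per_rank` is a hypothetical helper.

```python
import numpy as np

def batches_per_rank(rows_per_file, num_workers, world_size, batch_size):
    """Full batches each rank yields per epoch under the file split in iter_parquet.

    Assumes rows_per_file follows the (sorted) file order and that each DataLoader
    worker batches its own rows, so drop_last=True drops one partial batch per worker.
    """
    worker_total_num = num_workers * world_size
    splits = np.array_split(np.asarray(rows_per_file), worker_total_num)
    counts = []
    for rank in range(world_size):
        n_batches = 0
        for worker_id in range(num_workers):
            rows = int(splits[worker_id * world_size + rank].sum())
            n_batches += rows // batch_size  # only full batches survive drop_last=True
        counts.append(n_batches)
    return counts

print(batches_per_rank([1000] * 9, num_workers=2, world_size=2, batch_size=64))  # prints [77, 62]
```

For 9 equal files of 1000 rows, rank 0 would run 15 more steps than rank 1, which is exactly the kind of mismatch that can leave one rank blocked in an all-reduce.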

In my case, I know the number of rows and the number of parquet files in advance. Instead of splitting the parquet files between workers, I could also assign rows to workers, as is done here: DDP + Torchdata - #3 by ejguan
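Row-level assignment can be sketched as a round-robin filter: global worker `g` keeps rows `g`, `g + worker_total_num`, `g + 2 * worker_total_num`, and so on. This is a sketch in the spirit of the linked post, not its exact code; `shard_rows` is a hypothetical name, and `rows` stands for any iterator over rows in a deterministic order (for example, the inner loops of `iter_parquet` run over the sorted list of all files).

```python
from itertools import islice

def shard_rows(rows, global_worker_id, worker_total_num):
    """Round-robin: keep rows global_worker_id, global_worker_id + worker_total_num, ..."""
    return islice(rows, global_worker_id, None, worker_total_num)

# 10 rows split across 4 global workers (e.g. 2 ranks x 2 DataLoader workers)
shards = [list(shard_rows(range(10), g, 4)) for g in range(4)]
print(shards)  # prints [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
```

Per-worker row counts then differ by at most one, so ranks stay close, though they are not guaranteed to be equal after per-worker `drop_last`. The trade-off is that every worker still reads and decodes every file just to skip most of its rows.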

Finally, I am fine with throwing out some batches, as the training set is very large and they won't be missed.

I implemented the approach here, then figured out how many batches each process (GPU) would get, took the minimum, and stopped every process from running forward/backward on any batches beyond that amount in each epoch. All the hangs went away!
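Capping every rank at the same, minimum number of batches can be sketched with `itertools.islice`. `run_epoch` and `step` are hypothetical names: `step` stands for one forward/backward/optimizer step, and the per-rank counts are assumed to be precomputed from the known row counts (Parquet metadata, e.g. `pq.ParquetFile(path).metadata.num_rows`, gives these without reading the data).

```python
from itertools import islice

def run_epoch(loader, max_batches, step):
    """Run at most max_batches steps so every rank does the same number of forward/backward passes."""
    n_steps = 0
    # islice stops pulling from the loader after max_batches; remaining batches are discarded
    for batch in islice(loader, max_batches):
        step(batch)  # hypothetical: forward, backward, optimizer.step()
        n_steps += 1
    return n_steps

# Hypothetical per-rank batch counts, e.g. from a row-count calculation
per_rank_batches = [77, 62]
max_batches = min(per_rank_batches)

# Usage with a stand-in loader; in training this would be train_loader
seen = []
print(run_epoch(range(100), max_batches, seen.append))  # prints 62
```

Every rank must use the same `max_batches`, so it should be computed from the counts of all ranks, not just the local one.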