I am trying to train a model (on AWS SageMaker) using DDP. I am reading from multiple parquet files with an IterableDataset and attempting to split the reads between dataloader workers and DDP ranks.
import numpy as np
import pyarrow.parquet as pq
from glob import glob
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

class iterDataset_train(IterableDataset):
    def __init__(self, dirpath, args, force_non_dist=False):
        self.dirpath = dirpath
        self.dim_reshape_time = args.dim_reshape_time
        self.dim_reshape_features = args.dim_reshape_features
        self.world_size = args.world_size
        self.rank = args.rank
        if force_non_dist:
            self.is_distributed = False
        else:
            self.is_distributed = args.is_distributed

    def iter_parquet(self):
        # get_worker_info() returns None when num_workers == 0, so guard against that
        worker_info = get_worker_info()
        worker_total_num = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0
        if self.is_distributed:
            world_size = self.world_size
            rank = self.rank
        else:
            world_size = 1
            rank = 0
        # one shard per (dataloader worker, DDP rank) pair
        worker_total_num *= world_size
        global_worker_id = worker_id * world_size + rank
        # sort the glob so every process sees the files in the same order
        lst_glob = sorted(glob(f'{self.dirpath}/*.parquet'))
        lst_splits = np.array_split(lst_glob, worker_total_num)
        for fpath in lst_splits[global_worker_id].tolist():
            tbl = pq.ParquetFile(fpath)
            for group_i in range(tbl.num_row_groups):    # for each row group
                row_group = tbl.read_row_group(group_i)  # read the group
                for batch in row_group.to_batches():     # loop through batches of the group
                    for row in zip(*batch.columns):      # loop through the rows of the batch, combining the columns
                        features = row[1].values.to_numpy(zero_copy_only=False, writable=True) \
                            .reshape((self.dim_reshape_time, self.dim_reshape_features)).astype(np.float64)
                        targets = row[2].values.to_numpy(zero_copy_only=False, writable=True).astype(np.float64)
                        yield features, targets

    def __iter__(self):
        return self.iter_parquet()
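As a sanity check on the sharding math, a standalone snippet like this (with made-up worker/rank counts) confirms that every (worker_id, rank) pair maps to a distinct shard index:

    # toy check: 4 dataloader workers per rank, 2 DDP ranks
    num_workers, world_size = 4, 2
    ids = sorted(w * world_size + r for w in range(num_workers) for r in range(world_size))
    assert ids == list(range(num_workers * world_size))  # every shard claimed exactly once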
The data loader looks like this:
train_ds = iterDataset_train(args.train_dir, args)
train_loader = DataLoader(train_ds, shuffle=False, batch_size=args.batch_size_train,
                          drop_last=True, num_workers=args.num_workers)
I am seeing the script hang at various places. Before delving into debugging, I am wondering whether the way I iterate through the data is the issue. I have seen various sources suggest that each rank (each GPU, or maybe each node with multi-node training?) needs to get the same number of batches, presumably because DDP's gradient all-reduce blocks every step until all ranks contribute, so a rank that runs out of data early would leave the others waiting. Is this indeed the case, and is it perhaps what is causing my hangs? Is there a way, with the above, to make the batch counts match?
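One idea I had for making the counts match: since I can compute how many batches each rank will produce, all ranks could agree on the minimum up front and stop there. A rough sketch (assuming torch.distributed is already initialized; my_shard_rows and device are placeholders for this rank's known row count and its GPU, and I am ignoring the per-worker drop_last detail for simplicity):

    import torch
    import torch.distributed as dist

    # each rank computes its own prospective batch count, then all ranks take the min
    local_batches = torch.tensor([my_shard_rows // args.batch_size_train], device=device)  # my_shard_rows: placeholder
    dist.all_reduce(local_batches, op=dist.ReduceOp.MIN)
    num_batches = int(local_batches.item())

    for step, (features, targets) in enumerate(train_loader):
        if step >= num_batches:
            break  # every rank stops at the same step, so no one is left waiting in an all-reduce
        # ... forward / backward / optimizer step ...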
In my case, I know beforehand the number of rows in the data and the number of parquet files. Instead of splitting the parquet files between workers, I could also assign rows the way it is done here: DDP + Torchdata - #3 by ejguan
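My rough (untested) sketch of that row-level pattern, where iter_all_rows is a hypothetical helper and the last line would live inside iter_parquet so global_worker_id and worker_total_num are in scope:

    from itertools import islice

    def iter_all_rows(dirpath):
        # hypothetical helper: yield every row of every file, in a fixed order
        for fpath in sorted(glob(f'{dirpath}/*.parquet')):
            tbl = pq.ParquetFile(fpath)
            for group_i in range(tbl.num_row_groups):
                for batch in tbl.read_row_group(group_i).to_batches():
                    yield from zip(*batch.columns)

    # each shard keeps every worker_total_num-th row, offset by its own shard index
    shard_rows = islice(iter_all_rows(self.dirpath), global_worker_id, None, worker_total_num)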
Finally, I am fine with throwing out some batches, since the training set is very large and they won't be missed.
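Along those lines, since the total row count is known up front, each shard could presumably just be cut to the same length before training starts, e.g. (sketch; total_rows is the known overall row count, and shard_rows is the shard's row iterator from the sketch above):

    from itertools import islice

    # give every shard the same number of rows, rounded down to whole batches,
    # so every rank and worker emits an identical number of batches
    rows_per_shard = total_rows // worker_total_num
    rows_kept = (rows_per_shard // args.batch_size_train) * args.batch_size_train
    shard_rows = islice(shard_rows, rows_kept)  # surplus rows are simply dropped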