I have a custom datapipe for sharding that I use instead of the default `sharding_filter`. It roughly looks like this:
```python
# (module-level: import copy, random; `partition` is a project helper)
def __iter__(self):
    from distributed_ml.logging.logger import logger

    l = logger.bind(tag="FullyShardedChunkWrapper")
    l.debug("iter_start", epochs=self.epochs)
    # Bump the epoch counter and use it to seed this epoch's shuffle.
    self.epochs += 1
    rng = random.Random(self.epochs)
    all_chunks = copy.deepcopy(self.chunks)
    if self.global_chunk_shuffle:
        rng.shuffle(all_chunks)
    # Shard first across distributed ranks, then across dataloader workers
    # within the current rank.
    world_info = self.get_world_info()
    cur_rank_chunks = partition(all_chunks, world_info.size, world_info.rank)
    worker_id, total_workers = self.get_worker_info()
    cur_worker_chunks = partition(cur_rank_chunks, total_workers, worker_id)
    rng.shuffle(cur_worker_chunks)
    total = len(cur_worker_chunks)
    for idx, chunk in enumerate(cur_worker_chunks):
        l.debug("iter_yield", idx=idx, left=total - idx - 1)
        yield chunk
```
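For context, `partition` does roughly a strided split; a minimal sketch (the actual helper may differ):

```python
def partition(items, num_parts, part_index):
    """Return the strided shard of `items` owned by `part_index` of `num_parts`.

    Sketch only -- the real partition helper may split differently.
    """
    return items[part_index::num_parts]

# e.g. partition(list(range(10)), 4, 1) -> [1, 5, 9]
```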
I recently discovered that my datapipe was not actually re-shuffling the data when a new epoch rolled around, because `self.epochs` was never incremented.
My guess is that when a new epoch starts, the dataloader re-creates the datapipe from scratch, so maintaining / incrementing instance variables across epochs is a bad idea. But I wanted to confirm this intuition.
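For what it's worth, here is a minimal sketch that seems to reproduce the suspected behavior with a plain `IterableDataset` (an `IterDataPipe` goes through the same worker-copy machinery under the classic `DataLoader`; the class and names below are mine, not from the real pipe):

```python
from torch.utils.data import DataLoader, IterableDataset

class EpochCounter(IterableDataset):
    def __init__(self):
        self.epochs = 0

    def __iter__(self):
        # This mutates whichever copy of the dataset the worker process
        # holds, not the object living in the main process.
        self.epochs += 1
        yield self.epochs

if __name__ == "__main__":
    ds = EpochCounter()
    loader = DataLoader(ds, num_workers=1)
    for _ in range(3):
        # Prints [tensor([1])] every epoch: non-persistent workers are
        # re-spawned each epoch from a fresh copy of `ds`.
        print(list(loader))
    print(ds.epochs)  # still 0 in the main process
```

If that mental model is right, the fix would presumably be to drive the seed from outside the pipe, e.g. a `set_epoch`-style method called from the training loop (the pattern `DistributedSampler` uses), rather than mutating state inside `__iter__`.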
@ejguan