How to save DistributedSampler state for resuming later?

Hi,

Let’s say I am using a DistributedSampler for multi-gpu training in the following fashion:

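import horovod.torch as hvd
from torch.utils import data

hvd.init()  # must be called before hvd.size() / hvd.rank()

# Each Horovod process gets its own shard of the dataset
train_sampler = data.distributed.DistributedSampler(train_dataset,
                                                    num_replicas=hvd.size(),
                                                    rank=hvd.rank())

train_loader = data.DataLoader(dataset=train_dataset,
                               batch_size=config.dist_train.batch_size,
                               shuffle=False,  # shuffling is done by the sampler
                               num_workers=args.data_workers,
                               pin_memory=True,
                               drop_last=True,
                               sampler=train_sampler)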

During the training loop, I call set_epoch on the sampler at the start of each epoch to ensure reproducibility:

for epoch in range(...):
    train_sampler.set_epoch(epoch - 1)
    for i, _data in enumerate(train_loader):
        ...
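To see why set_epoch matters, here is a pure-Python model of how DistributedSampler builds each rank's index list when shuffle=True and the sampler's own drop_last=False. In recent PyTorch versions the sampler seeds its generator with seed + epoch (the seed constructor argument defaults to 0), so the order depends only on those two values. This sketch uses random.Random in place of torch.Generator, so the exact order differs from PyTorch's; the function name shard_indices is hypothetical.

```python
import math
import random


def shard_indices(dataset_len, num_replicas, rank, epoch, seed=0):
    """Model of DistributedSampler.__iter__ (shuffle=True, drop_last=False).

    Assumption: random.Random stands in for torch.Generator, so the order
    differs from PyTorch's, but it likewise depends only on seed + epoch.
    """
    rng = random.Random(seed + epoch)
    indices = list(range(dataset_len))
    rng.shuffle(indices)

    # Pad by repeating indices so every rank gets the same number of samples
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    while len(indices) < total_size:
        indices += indices[: total_size - len(indices)]

    # Each rank takes every num_replicas-th index, starting at its own rank
    return indices[rank:total_size:num_replicas]
```

Calling train_sampler.set_epoch(epoch - 1) corresponds to changing the epoch argument here: the same epoch always yields the same shard, and a new epoch yields a new shuffle.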

However, suppose I checkpoint the model in the middle of epoch 3:

torch.save({
    "epoch": epoch,
    "step": i,
    "total_iter": total_iter,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "scheduler_state_dict": scheduler.state_dict()
}, CHECKPOINT_PATH)
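The checkpoint above stores the step but not enough to tell whether a mid-epoch position is still valid on resume. A minimal sketch of extra fields one could merge into the checkpoint dict, assuming the checkpoint is written right after batch i (0-indexed) has been processed, so each rank has consumed (i + 1) * batch_size samples. The helper names and field names are hypothetical. The position only makes sense if the number of replicas, the sampler seed and the dataset length are unchanged, because any of those changes the shards.

```python
def make_resume_state(sampler_epoch, step, batch_size, num_replicas,
                      dataset_len, seed=0):
    """Hypothetical extra fields for the checkpoint dict.

    sampler_epoch is the value passed to set_epoch (epoch - 1 in the loop above).
    Assumes batch `step` (0-indexed) was fully processed before saving.
    """
    return {
        "sampler_epoch": sampler_epoch,
        "samples_seen": (step + 1) * batch_size,  # per rank
        # Sharding depends on these; record them so a mismatch can be detected
        "num_replicas": num_replicas,
        "dataset_len": dataset_len,
        "sampler_seed": seed,
    }


def samples_to_skip(state, num_replicas, dataset_len, seed=0):
    """Return how many of this rank's indices to skip, or fail if the setup changed."""
    saved = (state["num_replicas"], state["dataset_len"], state["sampler_seed"])
    current = (num_replicas, dataset_len, seed)
    if saved != current:
        raise ValueError(f"cannot resume mid-epoch: saved {saved}, current {current}")
    return state["samples_seen"]
```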

When I resume, the sampler iterates over the same data it already served in epoch 3, instead of remembering where it stopped and continuing with the unseen data. Is there any way to do this?

Thanks for your help!

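I think that if you fix a seed and also restore the last epoch value (and call sampler.set_epoch with it), the sampler will reproduce the same ordering for that epoch; then you only need to skip the samples already processed, and you should be good.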
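One way to do the skipping without loading the already-seen samples is to wrap the sampler, since DataLoader accepts any iterable with a length as its sampler argument. This is a minimal sketch, not a PyTorch API: the class ResumableSampler and its resume_from method are hypothetical. It only skips indices during the epoch being resumed, so later epochs see the full shard.

```python
import itertools


class ResumableSampler:
    """Hypothetical wrapper that skips the first indices of one resumed epoch.

    Works with any iterable sampler that has __len__; if the wrapped sampler
    has set_epoch (e.g. DistributedSampler), calls are forwarded to it.
    """

    def __init__(self, sampler):
        self.sampler = sampler
        self.epoch = 0
        self.resume_epoch = None
        self.skip = 0

    def set_epoch(self, epoch):
        self.epoch = epoch
        if hasattr(self.sampler, "set_epoch"):
            self.sampler.set_epoch(epoch)

    def resume_from(self, epoch, samples_seen):
        # epoch must be the value that was passed to set_epoch before saving
        self.resume_epoch = epoch
        self.skip = samples_seen

    def _current_skip(self):
        return self.skip if self.epoch == self.resume_epoch else 0

    def __iter__(self):
        # Skipping indices is cheap: no data is loaded for the skipped samples
        return itertools.islice(iter(self.sampler), self._current_skip(), None)

    def __len__(self):
        return max(len(self.sampler) - self._current_skip(), 0)
```

With the code from the question, the DistributedSampler would be wrapped before being passed to the DataLoader, and on resume one would call train_sampler.resume_from(checkpoint["epoch"] - 1, (checkpoint["step"] + 1) * config.dist_train.batch_size) before restarting the loop at checkpoint["epoch"]. Note that enumerate(train_loader) restarts i at 0 in the resumed epoch, so the saved step has to be added back if i is used for logging or scheduling.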