How to get DistributedRandomSampler to use a new random order in every epoch?

Hi everyone,

I am trying to use DistributedRandomSampler in the C++ frontend. However, it does not behave as I expect: it uses the same random order of examples in every epoch, even though I call set_epoch() before each one. I am probably doing something wrong, but I have not been able to find out what. The Python equivalent of the C++ code does what I expect. Can you help me find the problem?

Example:

#include <torch/torch.h>
#include <iostream>

int main(int argc, char *argv[]) {
  int N = 8;
  auto inputs = torch::arange(N).view({N, 1});
  auto dataset = torch::data::datasets::TensorDataset({inputs});

  // this works as expected: new random order in every epoch
  // torch::data::samplers::RandomSampler sampler (dataset.size().value());

  // this does not: same random order in every epoch
  torch::data::samplers::DistributedRandomSampler sampler (dataset.size().value(), /*num_replicas=*/2, /*rank=*/0);

  auto loader = torch::data::make_data_loader(
                  dataset,
                  sampler,
                  torch::data::DataLoaderOptions().batch_size(2));


  for (unsigned int epoch=0; epoch!=3; ++epoch) {
    std::cout << "====== epoch " << epoch << "\n";

    sampler.set_epoch(epoch);
    // sampler.reset(); // also tried to manually reset the sampler here, but this did not help

    unsigned long batch_idx = 0;
    for (auto& batch : *loader) {
      std::cout << "batch " << batch_idx << ": ";
      for (auto& example : batch) {
        std::cout << example.data[0].item<float>() << " ";
      }
      std::cout << "\n";
      ++batch_idx;
    }
  }

  return 0;
}

This produces batches with the same random order in every epoch:

====== epoch 0
batch 0: 0 2 
batch 1: 1 5 
====== epoch 1
batch 0: 0 2 
batch 1: 1 5 
====== epoch 2
batch 0: 0 2 
batch 1: 1 5 

What I want is a different order in every epoch, i.e., I would expect output along these lines (the exact numbers are only illustrative):

====== epoch 0
batch 0: 4 7
batch 1: 2 1
====== epoch 1
batch 0: 5 2
batch 1: 7 1
====== epoch 2
batch 0: 0 7
batch 1: 6 1

The following Python code produces the output that I expect:

import torch

N = 8
inputs = torch.arange(N, dtype=torch.float32).view(N, 1)
dataset = torch.utils.data.TensorDataset(inputs)

sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)
    print("====== epoch " + str(epoch))
    for batch_idx, (batch,) in enumerate(loader):
        print("batch " + str(batch_idx) + ": " + " ".join(str(int(x)) for x in batch.view(-1).tolist()))

I am using PyTorch version 1.10.0.

Thanks in advance for any help!

Cheers
Alexander

That's indeed weird, as I would expect set_epoch to behave the same way in libtorch.
Would you mind creating an issue on GitHub so that we could track it, please?

Thanks for the quick reply! I opened pytorch/pytorch issue #73141: "libtorch: `DistributedRandomSampler` uses the same random order in every epoch".