Hi everyone,

I am trying to use DistributedRandomSampler in the C++ frontend. However, it is not behaving as I would expect: it produces the same random order of examples in every epoch, despite calling set_epoch(). I am probably doing something wrong, but I have not been able to find what it is. The equivalent Python code does what I expect. Can you help me spot the problem?
Example:
#include <torch/torch.h>
#include <iostream>

int main(int argc, char *argv[]) {
  int N = 8;
  auto inputs = torch::arange(N).view({N, 1});
  auto dataset = torch::data::datasets::TensorDataset({inputs});

  // this works as expected: new random order in every epoch
  // torch::data::samplers::RandomSampler sampler(dataset.size().value());

  // this does not: same random order in every epoch
  torch::data::samplers::DistributedRandomSampler sampler(
      dataset.size().value(), /*num_replicas=*/2, /*rank=*/0);

  auto loader = torch::data::make_data_loader(
      dataset,
      sampler,
      torch::data::DataLoaderOptions().batch_size(2));

  for (unsigned int epoch = 0; epoch != 3; ++epoch) {
    std::cout << "====== epoch " << epoch << "\n";
    sampler.set_epoch(epoch);
    // sampler.reset(); // also tried to manually reset the sampler here, but this did not help

    unsigned long batch_idx = 0;
    for (auto& batch : *loader) {
      std::cout << "batch " << batch_idx << ": ";
      for (auto& example : batch) {
        std::cout << example.data[0].item<float>() << " ";
      }
      std::cout << "\n";
      ++batch_idx;
    }
  }
  return 0;
}
This produces batches with the same random order in every epoch:
====== epoch 0
batch 0: 0 2
batch 1: 1 5
====== epoch 1
batch 0: 0 2
batch 1: 1 5
====== epoch 2
batch 0: 0 2
batch 1: 1 5
I would like a different order in every epoch, i.e., I would expect the output to look something like this (any order is fine, just not the same one each time):
====== epoch 0
batch 0: 4 7
batch 1: 2 1
====== epoch 1
batch 0: 5 2
batch 1: 7 1
====== epoch 2
batch 0: 0 7
batch 1: 6 1
The following Python code produces the output that I expect:
import torch

N = 8
inputs = torch.arange(N, dtype=torch.float32).view(N, 1)
dataset = torch.utils.data.TensorDataset(inputs)
sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)
    print("====== epoch " + str(epoch))
    for batch_idx, (input,) in enumerate(loader):
        print("batch " + str(batch_idx) + ": " + " ".join([str(int(x)) for x in input.squeeze().tolist()]))
I am using PyTorch version 1.10.0.
Thanks in advance for any help!
Cheers
Alexander