Indices of a dataset sampled by DataLoader

(Peter Cavén) #1

Is there a recommended method of obtaining the indices of the dataset that are sampled by a torch.utils.data.DataLoader?

I know that a RandomSampler will return a list of indices, but there doesn’t seem to be a way to access those indices once the sampler has been passed to a torch.utils.data.DataLoader.

(Olof Harrysson) #2

Hi there!

I don’t know if there is a recommended method, but I managed to do it with the sampler below. It produces the same indexes on every pass until you call sampler.shuffle(). So if you want the indexes, keep a reference to your sampler and read them out with indexes = list(iter(sampler)). The iter() call is optional, since list(sampler) calls __iter__ itself.

The downside is that you have to reshuffle the indexes yourself at the end of every epoch by calling the shuffle method.

Do you think something like this would work for you?

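import random

from torch.utils.data import Sampler


class RandomSampler(Sampler):
  """Yields the same shuffled order on every pass until shuffle() is called."""

  def __init__(self, data_source):
    self.data_source = data_source
    self.shuffle()

  def shuffle(self):
    # Draw a new seed; the next __iter__ call will produce a new order.
    self.seed = random.randint(0, 2**32 - 1)

  def __iter__(self):
    n = len(self.data_source)
    indexes = list(range(n))
    # A private Random instance seeded with self.seed makes the order reproducible.
    random.Random(self.seed).shuffle(indexes)
    return iter(indexes)

  def __len__(self):
    return len(self.data_source)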
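
The reason the order stays fixed between shuffle() calls is that the permutation depends only on the stored seed. A minimal sketch of that mechanism in plain Python follows; order_for and the per-epoch bookkeeping are hypothetical names used for illustration, not part of PyTorch.

```python
import random


def order_for(seed, n):
    """Return the permutation of range(n) produced by a given seed."""
    indexes = list(range(n))
    random.Random(seed).shuffle(indexes)
    return indexes


# The same seed always gives the same order.
seed = random.randint(0, 2**32 - 1)
assert order_for(seed, 10) == order_for(seed, 10)

# Assumption: one fresh seed per epoch, as if shuffle() were called between epochs.
epoch_seeds = [random.randint(0, 2**32 - 1) for _ in range(3)]
history = {epoch: order_for(s, 10) for epoch, s in enumerate(epoch_seeds)}
```

Under this scheme, saving the seed for each epoch should be enough to rebuild that epoch's index order later, without storing the full list.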

(Peter Cavén) #3

OK thanks.
I see the idea - generate a random permutation of the indexes ahead of time, then just pass that to a SequentialSampler.
Since the order of indexes is fixed I can just extract batch_size indexes at a time from that list.
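
A rough sketch of that extraction step, in plain Python: batch_indices is a hypothetical helper, and its drop_last flag is only meant to mirror the DataLoader option of the same name. It assumes the loader consumes the fixed order front to back with the same batch_size and no further shuffling (for instance when the permuted list itself is passed as the DataLoader's sampler argument, which accepts any iterable with a length).

```python
import random


def batch_indices(order, batch_size, drop_last=False):
    """Split a fixed index order into consecutive batches of batch_size."""
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    # Optionally discard a short final batch.
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


n, batch_size = 10, 4
order = list(range(n))
random.Random(0).shuffle(order)  # fixed permutation, generated ahead of time

batches = batch_indices(order, batch_size)
# batches[k] holds the dataset indices that make up the k-th batch.
```

With n = 10 and batch_size = 4, the last batch holds the two leftover indices unless drop_last=True is passed.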
