Is there a recommended method of obtaining the indices of the dataset that are sampled by a
I know that a RandomSampler will return a list of indices, but there doesn’t seem to be a way to access those indices once the sampler has been passed to a
I don’t know if there is a recommended method, but I managed to do it in the way below. This will produce the same indexes until you call sampler.shuffle(). So if you want the indexes, you could just save your sampler and produce them like so:
indexes = list(iter(sampler))
Don’t know if you need that iter part tbh but can’t test now.
The downside is that you have to reshuffle your indexes manually at the end of every epoch by calling the shuffle method.
Do you think something like this would work for you?
    def __init__(self, data_source):
        self.data_source = data_source
        self.seed = random.randint(0, 2**32 - 1)
        n = len(self.data_source)
        indexes = list(range(n))
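The snippet above looks cut off, so here is a minimal sketch of how the whole sampler might fit together. The class name SeededSampler, the body of shuffle(), and the __iter__/__len__ methods are my assumptions, not the original poster's code. It is written in plain Python so it runs without torch; in a real project you would probably subclass torch.utils.data.Sampler instead.

```python
import random


class SeededSampler:
    """Hypothetical name: yields the same permutation until shuffle() is called."""

    def __init__(self, data_source):
        self.data_source = data_source
        self.seed = random.randint(0, 2**32 - 1)

    def shuffle(self):
        # Assumed behaviour: draw a new seed so later iterations use a new order.
        self.seed = random.randint(0, 2**32 - 1)

    def __iter__(self):
        n = len(self.data_source)
        indexes = list(range(n))
        # A private Random instance makes the order depend only on self.seed.
        random.Random(self.seed).shuffle(indexes)
        return iter(indexes)

    def __len__(self):
        return len(self.data_source)


sampler = SeededSampler(list("abcdefgh"))
first = list(sampler)          # list() calls iter() for you
second = list(iter(sampler))   # so the explicit iter() gives the same result
```

Because the permutation is rebuilt from the stored seed on every iteration, first and second hold the same order, and the explicit iter() call is optional. Calling sampler.shuffle() between epochs is what changes the order.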
I see the idea - generate a random permutation of the indexes ahead of time, then just pass that to a
Since the order of indexes is fixed, I can just extract batch_size indexes at a time from that list.
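For what it's worth, pulling batch_size indexes at a time from a fixed list could be sketched like this. The helper name batch_indexes and the example order are made up for illustration; the last batch is simply shorter when the list length is not a multiple of batch_size.

```python
def batch_indexes(indexes, batch_size):
    """Yield consecutive slices of `indexes`, each at most `batch_size` long."""
    for start in range(0, len(indexes), batch_size):
        yield indexes[start:start + batch_size]


order = [3, 0, 4, 1, 2]  # stand-in for a saved permutation of dataset indexes
batches = list(batch_indexes(order, 2))
```

Each element of batches is then the exact set of dataset indexes that makes up one batch, so it can be logged or saved alongside the batch itself.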