Resuming RandomSampler from an intermediate sample in a deterministic way

I would like to know whether we can resume iteration over a RandomSampler from a given point in a deterministic way, i.e., stop execution in the middle of iterating over a torch.utils.data.RandomSampler object, save the RNG state, and later load that RNG state so that execution is equivalent to the original program had we not stopped it. The following code sample summarizes the issue:

import torch
from torch.utils.data import TensorDataset, RandomSampler

print(torch.__version__)  # '1.6.0'
torch.manual_seed(0)
data = TensorDataset(torch.arange(10))  # [0, ..., 9]

# we will sample two random numbers with replacement
rand_samp = RandomSampler(data, replacement=True, num_samples=2)

state_0 = torch.get_rng_state() # save initial random state
for x in rand_samp:  # we print the two numbers
    print(x)  # 4; 9

torch.set_rng_state(state_0)  # go back to previous state

for x in rand_samp:  # print the first number and stop the sampler
    print(x)   # 4
    state_1 = torch.get_rng_state()  # try to save state after the first batch
    break

# script stops, we try to resume from the next batch

torch.set_rng_state(state_1)
for x in rand_samp:
    print(x)  # prints 3, but expected 9
    break

I think the way RandomSampler currently works is that it draws all the random indices up front and then yields them one at a time, effectively losing the ability to stop execution during the for loop and resume later. One workaround is to create a new RandomSampler with num_samples=1 at each iteration, so that it returns one sample at a time. That way we can effectively save the RNG state between individual draws, but this is a bit cumbersome and probably inefficient.
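
A minimal sketch of that workaround, under the same setup as above (draw_one is just a hypothetical helper name): a fresh single-sample sampler consumes exactly one RNG draw per call, so the RNG state saved between calls is enough to resume deterministically.

import torch
from torch.utils.data import TensorDataset, RandomSampler

torch.manual_seed(0)
data = TensorDataset(torch.arange(10))

def draw_one(dataset):
    # a fresh RandomSampler with num_samples=1 consumes exactly one RNG draw
    sampler = RandomSampler(dataset, replacement=True, num_samples=1)
    return next(iter(sampler))

print(draw_one(data))          # first index
state = torch.get_rng_state()  # checkpoint between draws
print(draw_one(data))          # second index

torch.set_rng_state(state)     # restore and "resume"
print(draw_one(data))          # reproduces the second index

Note that the indices drawn one at a time will not necessarily match the ones drawn with num_samples=2; the point is only that the sequence can be checkpointed and resumed.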

I wanted to know if there is some better workaround, other than modifying the code of the RandomSampler class itself. The method to modify would be the __iter__ method, at line 113 of pytorch/sampler.py at master · pytorch/pytorch · GitHub.
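
For reference, here is roughly what I have in mind, not the actual library code but a hypothetical custom sampler (call it StreamingRandomSampler) that draws one index per yield, so that torch.get_rng_state() taken between yields fully determines the remaining draws:

import torch
from torch.utils.data import Sampler

class StreamingRandomSampler(Sampler):
    """Samples with replacement, one torch.randint call per yielded index."""

    def __init__(self, data_source, num_samples):
        self.data_source = data_source
        self.num_samples = num_samples

    def __iter__(self):
        n = len(self.data_source)
        for _ in range(self.num_samples):
            # one RNG draw per yield, so the state can be saved mid-iteration
            yield torch.randint(high=n, size=(1,), dtype=torch.int64).item()

    def __len__(self):
        return self.num_samples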

Thanks for any help / feedback!