Efficient way to create Dataset object from synthetic data

Hi folks,

I have a function sample_fn that creates data on the fly, calling torch.rand many times internally.

What I’d like to do is create a Dataset object from it ensuring:

  1. It is efficient in terms of generating and loading the data

  2. There is some diversity in the generated data (e.g., torch.rand can be used without passing an RNG generator on every call).
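For concreteness, here is a toy stand-in for sample_fn of the kind I mean (my real function is more involved; the shapes and the target are just placeholders):

```python
import torch

# Hypothetical toy version of my sample_fn: returns a dict of
# input/output tensors, drawing fresh random numbers on each call.
def sample_fn(nb_samples, seq_len):
    X = torch.rand(nb_samples, seq_len)   # random inputs in [0, 1)
    y = X.sum(dim=1, keepdim=True)        # toy target derived from X
    return {'input': X, 'output': y}
```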

So far, my attempt is the following:

import torch
import torch.utils.data as tud

torch.manual_seed(0)  # set seed for reproducibility
# do some other stuff

# create dataset
class SyntheticDataset(tud.Dataset):

    def __init__(self, sample_fn, nb_samples, seq_len):
        self.nb_samples = nb_samples
        self.samples = sample_fn(nb_samples, seq_len) # do I need to pass rng here for 2?

    def __getitem__(self, idx):
        X = self.samples['input'][idx]
        y = self.samples['output'][idx]
        return X, y

    def __len__(self):
        # self.y is never defined; use the stored sample count instead
        return self.nb_samples

# train model
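One alternative I've considered is generating each sample lazily in __getitem__ instead of materialising everything up front, so each access draws fresh randomness. A sketch of what I mean (toy_sample_fn is a hypothetical stand-in for my real function):

```python
import torch
import torch.utils.data as tud

# Hypothetical toy generator; my real sample_fn is more complex.
def toy_sample_fn(nb_samples, seq_len):
    X = torch.rand(nb_samples, seq_len)
    return {'input': X, 'output': X.sum(dim=1, keepdim=True)}

class LazySyntheticDataset(tud.Dataset):
    """Draws each sample on demand rather than precomputing all of them."""

    def __init__(self, sample_fn, nb_samples, seq_len):
        self.sample_fn = sample_fn
        self.nb_samples = nb_samples
        self.seq_len = seq_len

    def __getitem__(self, idx):
        # One fresh sample per call; torch's global RNG advances each time,
        # so repeated accesses (and epochs) see different data.
        sample = self.sample_fn(1, self.seq_len)
        return sample['input'][0], sample['output'][0]

    def __len__(self):
        return self.nb_samples
```

I'm not sure how this interacts with num_workers > 0 in a DataLoader (each worker gets its own RNG state), which is part of why I'm asking.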

Is there a better way to achieve those two goals?