Hi folks,

I have a function `sample_fn` that creates data on the fly by calling `torch.rand` many times internally. What I'd like to do is create a `Dataset` object from it, ensuring two things:
- It is efficient in terms of generating and loading the data.
- There is some diversity in the generated data (e.g., `torch.rand` isn't re-seeded with the same RNG generator on every call, which would produce identical samples).
So far, my attempt is the following:
```python
import random

import numpy as np
import torch
import torch.utils.data as tud

# set seeds for reproducibility
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# do some other stuff

# create dataset
class SyntheticDataset(tud.Dataset):
    def __init__(self, sample_fn, nb_samples, seq_len):
        self.nb_samples = nb_samples
        # all samples are generated eagerly, up front
        self.samples = sample_fn(nb_samples, seq_len)  # do I need to pass an RNG here for goal 2?

    def __getitem__(self, idx):
        X = self.samples['input'][idx]
        y = self.samples['output'][idx]
        return X, y

    def __len__(self):
        return self.nb_samples  # was len(self.y), but self.y doesn't exist

# train model
```
Is there a better way to achieve those two goals?