Is there a way to re-permute the indices used by SubsetRandomSampler at each epoch without re-initialising the DataLoader every epoch? I am loading a large dataset, so I do not want to reload it at every epoch.
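To make the question concrete, here is a minimal standalone sketch of the behaviour I am asking about (the toy TensorDataset and sizes are made up for illustration, standing in for my real HDF5-backed dataset): iterate a single DataLoader for two "epochs" and check whether the sample order changes between passes.

import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Toy stand-in for my real dataset (hypothetical, just for this sketch).
dataset = TensorDataset(torch.arange(10))
sampler = SubsetRandomSampler(list(range(len(dataset))))
loader = DataLoader(dataset, batch_size=2, sampler=sampler, shuffle=False)

# Iterate the SAME loader twice: does the permutation differ between
# passes without rebuilding the DataLoader?
for epoch in range(2):
    order = [x.item() for (batch,) in loader for x in batch]
    print(f"epoch {epoch}: {order}")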
My dataloader function is:
import torch
import data_io  # my own module wrapping the HDF5 dataset

def train_dataloader(args):
    train_dataset_file = "{0}/train_data_{1}.hdf5".format(args.locations["train_test_datadir"], args.region)
    train_dataset = data_io.ConcatDataset("train", args.nlevs, train_dataset_file, args.locations['normaliser_loc'], args.batch_size,
                                          xvars=args.xvars, yvars=args.yvars, yvars2=args.yvars2, samples_frac=args.samples_fraction,
                                          data_frac=args.data_fraction, no_norm=args.no_norm)
    indices = list(range(len(train_dataset)))
    train_sampler = torch.utils.data.SubsetRandomSampler(indices)
    # batch_size=None disables automatic batching: the dataset already
    # returns whole batches itself (it is handed args.batch_size above).
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=None, sampler=train_sampler, shuffle=False)
    return train_loader
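What I want to avoid is rebuilding the loader inside the epoch loop just to get a fresh shuffle, since that would reload the large HDF5 file every epoch. A sketch of that anti-pattern, for clarity:

# Pattern I want to avoid: re-creating the DataLoader (and therefore
# re-loading the large dataset) once per epoch just to reshuffle.
for epoch in range(1, args.epochs + 1):
    train_ldr = train_dataloader(args)  # reloads the HDF5 file each time
    for batch_idx, batch in enumerate(train_ldr):
        ...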
And I call it as follows:
def train_loop(model, loss_function, optimizer, scheduler, args):
    training_loss = []
    validation_loss = []
    # Build the loaders once, outside the epoch loop, so the large
    # dataset is only loaded a single time.
    train_ldr = train_dataloader(args)
    test_ldr = test_dataloader(args)
    for epoch in range(1, args.epochs + 1):
        ## Training
        train_loss = 0
        for batch_idx, batch in enumerate(train_ldr):
            # print(batch_idx)
            model.train()  # sets the model into training mode