# Split a 10-element TensorDataset into two disjoint halves with ``Subset``
# and print every sample of each half.
import torch
from torch.utils.data import TensorDataset

dataset = TensorDataset(torch.arange(10))
# Indices 0..4 go to the training split and 5..9 to the validation split,
# so the two subsets never share a sample.
train_dataset = torch.utils.data.Subset(dataset, indices=torch.arange(5))
val_dataset = torch.utils.data.Subset(dataset, indices=torch.arange(5, 10))

# Subset defines no __iter__, so iteration goes through __getitem__ with
# 0, 1, 2, ... and stops when indexing past the end raises IndexError.
# TensorDataset yields tuples with one element per wrapped tensor, so each
# ``d`` here is a 1-tuple.
for d in train_dataset:
    print(d)
for d in val_dataset:
    print(d)