How to correctly load data into a PyTorch DataLoader

I have a dataset of galaxy images and corresponding labels stored in an .h5 file, which I want to use to train a classification network. First, I want to convert the images to grayscale and normalize them. Afterwards, the data should be split into a training set and a validation set. So far my code looks as follows:

import h5py
import numpy as np
import matplotlib.pyplot as plt
import random
import torch

# Load dataset
f = h5py.File('Galaxy10.h5', 'r')

# The file contains two keys: 'ans' and 'images'
# 'ans' holds the labels with shape (21785,); 'images' has shape (21785, 69, 69, 3)
# [()] reads each dataset fully into memory as a NumPy array
labels, images = f['ans'][()], f['images'][()]
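Indexing an open h5py file with f['ans'] returns an h5py Dataset that reads lazily from disk, not a NumPy array. A minimal sketch, assuming the whole file fits in memory, that reads both keys into NumPy arrays and closes the file afterwards; the helper name load_galaxy_h5 and the tiny stand-in file are hypothetical, only the key names come from the question:

```python
import os
import tempfile

import h5py
import numpy as np


def load_galaxy_h5(path):
    """Hypothetical helper: read labels and images fully into memory."""
    # The 'with' block closes the file; [()] copies each dataset into a NumPy array
    with h5py.File(path, "r") as f:
        labels = f["ans"][()]
        images = f["images"][()]
    return labels, images


# Usage: write a tiny stand-in file with the same keys (not the real Galaxy10 data)
tmp_path = os.path.join(tempfile.mkdtemp(), "tiny_galaxy.h5")
with h5py.File(tmp_path, "w") as f:
    f.create_dataset("ans", data=np.arange(5, dtype=np.uint8))
    f.create_dataset("images", data=np.zeros((5, 69, 69, 3), dtype=np.uint8))

labels, images = load_galaxy_h5(tmp_path)
print(type(labels), labels.shape, images.shape)
```

Having plain arrays avoids h5py's indexing restrictions later, when the labels are selected with arbitrary index lists.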

# Next, convert the RGB images to grayscale and standardize them to zero mean and unit standard deviation
images = np.dot(images[..., :3], [0.2989, 0.5870, 0.1140])
images = (images - np.mean(images)) / np.std(images)
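The two lines above can be wrapped in a small function and checked on random data. This is a sketch assuming uint8 RGB input of shape (N, H, W, 3); the function name to_standardized_grayscale is hypothetical, and the weights are the ones used in the question:

```python
import numpy as np

# Luma weights taken from the question
RGB_WEIGHTS = np.array([0.2989, 0.5870, 0.1140])


def to_standardized_grayscale(images):
    """(N, H, W, 3) RGB array -> (N, H, W) grayscale array with mean 0 and std 1."""
    # Weighted sum over the last (channel) axis; the result is float64
    gray = images[..., :3] @ RGB_WEIGHTS
    # One global mean and std computed over the whole array
    return (gray - gray.mean()) / gray.std()


# Usage with random uint8 stand-in images
rng = np.random.default_rng(0)
fake_rgb = rng.integers(0, 256, size=(8, 69, 69, 3), dtype=np.uint8)
gray = to_standardized_grayscale(fake_rgb)
print(gray.shape, round(gray.mean(), 6), round(gray.std(), 6))
```

A common refinement is to compute the mean and std on the training images only and reuse those values for the validation images, so that no validation statistics leak into training.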

from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

# First, split the dataset into a training set and a validation set
validation_split = .2

# Creating data indices for training and validation splits:
dataset_size = len(images)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
train_indices, val_indices = indices[split:], indices[:split]
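The split above is sequential: the validation set is simply the first 20% of the file. If the file happens to be ordered by class, that would give unbalanced subsets, so a sketch of a shuffled but reproducible split may help (the helper name split_indices and the seed value are assumptions). Note that h5py Datasets only accept index lists in increasing order, so shuffled indices should be applied to NumPy arrays, e.g. after np.asarray(labels):

```python
import numpy as np


def split_indices(dataset_size, validation_split=0.2, seed=42):
    """Hypothetical helper: return shuffled, disjoint (train_indices, val_indices) lists."""
    indices = np.arange(dataset_size)
    # Shuffle once with a fixed seed so the split is random but reproducible
    rng = np.random.default_rng(seed)
    rng.shuffle(indices)
    split = int(np.floor(validation_split * dataset_size))
    return indices[split:].tolist(), indices[:split].tolist()


train_indices, val_indices = split_indices(21785)
print(len(train_indices), len(val_indices))
```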

# Create PyTorch samplers that draw random indices from each subset:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

How do I correctly load the images together with the labels into the DataLoader now? I assume that I have to create two separate DataLoaders, one for the training set and one for the validation set.
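One way to use the two SubsetRandomSampler objects is to wrap all images and labels in a single TensorDataset and give each DataLoader its own sampler. A minimal sketch, assuming the network is Conv2d-based (hence the added channel dimension) and uses nn.CrossEntropyLoss (hence long labels); make_loaders and the random stand-in data are hypothetical:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler


def make_loaders(images, labels, train_indices, val_indices, batch_size=512):
    """Build train/validation DataLoaders that share one TensorDataset.

    images: float array of shape (N, H, W); labels: integer array of shape (N,).
    """
    # Add a channel dimension: (N, H, W) -> (N, 1, H, W)
    x = torch.tensor(images, dtype=torch.float32).unsqueeze(1)
    # CrossEntropyLoss expects integer class indices (dtype long)
    y = torch.tensor(labels, dtype=torch.long)
    dataset = TensorDataset(x, y)

    # Each sampler only draws from its own subset; 'shuffle' must not be
    # set when a sampler is passed
    train_loader = DataLoader(dataset, batch_size=batch_size,
                              sampler=SubsetRandomSampler(train_indices))
    val_loader = DataLoader(dataset, batch_size=batch_size,
                            sampler=SubsetRandomSampler(val_indices))
    return train_loader, val_loader


# Usage with small random stand-in data (not the real Galaxy10 file)
rng = np.random.default_rng(0)
fake_images = rng.standard_normal((100, 69, 69))
fake_labels = rng.integers(0, 10, size=100)
train_loader, val_loader = make_loaders(fake_images, fake_labels,
                                        list(range(20, 100)), list(range(20)),
                                        batch_size=32)
x, y = next(iter(train_loader))
print(x.shape, y.shape, y.dtype)
```

Because both loaders index the same dataset, no image is copied twice; the samplers alone decide which samples each loader sees.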

Thanks for your help!

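from torch.utils.data import TensorDataset, DataLoader

BATCH_SIZE = 512

# Float32 images with a channel dimension, shape (N, 1, 69, 69), for Conv2d layers
tensor_x = torch.tensor(images[train_indices], dtype=torch.float32).unsqueeze(1)
# Integer (long) class indices, as expected by nn.CrossEntropyLoss
tensor_y = torch.tensor(np.asarray(labels)[train_indices], dtype=torch.long)
dataset_train = TensorDataset(tensor_x, tensor_y)

# shuffle=True reshuffles the training samples at every epoch
train_loader = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)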
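The validation loader can be built the same way from val_indices. A sketch under the same assumptions (Conv2d input, integer labels); the helper make_loader and the random stand-in arrays are hypothetical:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

BATCH_SIZE = 512


def make_loader(images, labels, indices, shuffle):
    """Hypothetical helper: DataLoader over images[indices], labels[indices]."""
    x = torch.tensor(images[indices], dtype=torch.float32).unsqueeze(1)
    y = torch.tensor(labels[indices], dtype=torch.long)
    return DataLoader(TensorDataset(x, y), batch_size=BATCH_SIZE, shuffle=shuffle)


# Stand-in data; in the question these would be the standardized images,
# np.asarray(labels), and the train/validation index lists
rng = np.random.default_rng(0)
images = rng.standard_normal((1000, 69, 69))
labels = rng.integers(0, 10, size=1000)
train_indices, val_indices = list(range(200, 1000)), list(range(200))

train_loader = make_loader(images, labels, train_indices, shuffle=True)
# No shuffling for validation: the order does not affect evaluation
val_loader = make_loader(images, labels, val_indices, shuffle=False)
print(len(train_loader.dataset), len(val_loader.dataset))
```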

You can then iterate over the batches:

for i, (x, y) in enumerate(train_loader):
    print("x.shape {}, y.shape {}".format(x.shape, y.shape))
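To show how such batches feed the classification network, here is a minimal sketch of a training and validation loop. It assumes 10 classes and batches shaped (B, 1, 69, 69) with long labels; the tiny CNN, the optimizer settings and the random stand-in loaders are hypothetical and not a recommended architecture:

```python
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Stand-in loaders with the same tensor shapes as above (random data)
rng = np.random.default_rng(0)
x_all = torch.tensor(rng.standard_normal((256, 1, 69, 69)), dtype=torch.float32)
y_all = torch.tensor(rng.integers(0, 10, size=256), dtype=torch.long)
train_loader = DataLoader(TensorDataset(x_all[:200], y_all[:200]), batch_size=64, shuffle=True)
val_loader = DataLoader(TensorDataset(x_all[200:], y_all[200:]), batch_size=64)

# Hypothetical tiny CNN for 10 classes
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # -> (8, 34, 34)
    nn.Flatten(), nn.Linear(8 * 34 * 34, 10),
)
loss_fn = nn.CrossEntropyLoss()  # targets must be long class indices
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(2):
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()

    # Evaluate on the validation loader without tracking gradients
    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in val_loader:
            correct += (model(x).argmax(dim=1) == y).sum().item()
    val_acc = correct / len(val_loader.dataset)
    print(f"epoch {epoch}: last train loss {loss.item():.3f}, val acc {val_acc:.2f}")
```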