Best Practice for sampling from in-memory data

Hello there,
I am looking for a best practice for sampling from in-memory data.
Let’s assume there is a tensor with shape [10, 5000, 1, 224, 224], where 10 is the number of models, 5000 is the number of samples, and the last three dimensions are the dimensions per sample. The task is to sample from this tensor so that a certain number of samples are presented, e.g. to a model ensemble. A standard custom random batch sampler would look something like this:

class CustomSampler:

    def __init__(self, data, labels, batch_size):
        self.data = data
        self.labels = labels
        self.batch_size = batch_size
        self.n_samples = data.shape[1]

    def get_index(self):
        # draw batch_size random sample indices
        return torch.randint(0, self.n_samples, (self.batch_size,))

    def __next__(self):
        batch_idx = self.get_index()
        img = self.data[:, batch_idx]
        labels = self.labels[:, batch_idx]
        return (img, labels)

Note that data and labels are created by going through a big preprocessing pipeline resulting in the tensors already having been moved to the GPU.
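For illustration, here is a minimal, self-contained version of that indexing pattern with scaled-down toy shapes (the tensor names and sizes here are placeholders, not the real pipeline; on real data the tensors would already sit on the GPU and the indexing would stay on-device):

    import torch

    # Toy stand-ins for the real tensors (much smaller shapes).
    n_models, n_samples, batch_size = 3, 50, 8
    data = torch.randn(n_models, n_samples, 1, 8, 8)
    labels = torch.randint(0, 10, (n_models, n_samples))

    # Fancy indexing along dim 1 gathers one random batch for
    # every model at once, without any Python-level loop.
    batch_idx = torch.randint(0, n_samples, (batch_size,))
    img = data[:, batch_idx]
    lbl = labels[:, batch_idx]
    print(img.shape)  # -> torch.Size([3, 8, 1, 8, 8])
    print(lbl.shape)  # -> torch.Size([3, 8])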

Is there a better practice than the approach above for sampling from data that has already been loaded into memory and exists as one large tensor?

Looking forward to your answers.