Torch Dataset class won't emit right index from h5 file

Why won't my Torch Dataset class return the item at the requested index from an h5 file?

It keeps emitting the same first batch over and over.

```python
import h5py
import numpy as np
import torch

# prepare torch data set
class H5DataSet(torch.utils.data.Dataset):
    '''
    This class loads H5 files for deep learning.

    Parameters
    ----------
    path : str
        Path to an H5 file containing an "embeds" dataset
        shaped [layers, batch_sz, tokens, features].

    Returns
    -------
    self.dataset : numpy array
        The slice read by __getitem__, with the following
        dimensions: [layers, tokens, features]
    '''
    def __init__(self, path):
        self.file_path = path
        self.dataset = None
        with h5py.File(self.file_path, 'r') as file:
            self.dataset_len = len(file["embeds"])

    def __getitem__(self, index):
        if self.dataset is None:
            # embeds are shaped: (13, 8512, 64, 768)
            self.dataset = h5py.File(self.file_path, 'r')["embeds"][:, index, :, :]
        return self.dataset

    def __len__(self):
        return self.dataset_len

embeddings_ds = H5DataSet(path='D:\\h5py_embeds\\bert_embeds.h5')

# check data
for idx, batch in enumerate(embeddings_ds):
    if idx == 0:
        batch_0 = batch
    if idx == 1:
        batch_1 = batch
        break

# this assertion passes, i.e. batch_0 and batch_1 are identical
assert np.array_equal(batch_0, batch_1), 'not a match'
```
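One likely cause, judging from the code above: `__getitem__` stores the *sliced data* in `self.dataset` on the first call, so every later call skips the `if` branch and returns that first slice no matter which `index` is passed. A second, smaller issue is that `len(file["embeds"])` reports the size of the first axis, which in this layout is `layers` (13), not the 8512 samples. The sketch below is a minimal rework that assumes the file keeps the current `[layers, batch_sz, tokens, features]` layout: it caches the open `h5py.File` handle instead of the data, reads a fresh slice per index, and takes the length from axis 1. The `key` parameter and the explicit `IndexError` bounds check are additions for illustration, not part of the original code.

```python
import h5py
import torch
from torch.utils.data import Dataset


class H5DataSet(Dataset):
    """Reads one sample per index from an HDF5 dataset.

    Assumes the dataset is laid out as [layers, n_samples, tokens, features],
    as in the question. `key` is a hypothetical convenience parameter.
    """

    def __init__(self, path, key="embeds"):
        self.file_path = path
        self.key = key
        self._file = None  # opened lazily on first access
        with h5py.File(self.file_path, "r") as f:
            # axis 1 holds the samples here; len() would return axis 0 (layers)
            self.dataset_len = f[self.key].shape[1]

    def __getitem__(self, index):
        # explicit bounds check so plain iteration over the dataset stops cleanly
        if not 0 <= index < self.dataset_len:
            raise IndexError(index)
        if self._file is None:
            # cache the file handle, NOT a slice of the data
            self._file = h5py.File(self.file_path, "r")
        # read a fresh slice on every call: shape (layers, tokens, features)
        sample = self._file[self.key][:, index, :, :]
        return torch.from_numpy(sample)

    def __len__(self):
        return self.dataset_len
```

Because each item is now a single sample, batching can be left to `torch.utils.data.DataLoader`; for example, `DataLoader(ds, batch_size=32)` yields tensors shaped `[32, layers, tokens, features]`. Opening the file lazily inside `__getitem__` rather than in `__init__` is a commonly used pattern when `num_workers > 0`, since each worker process then opens its own handle.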

Edit: I've realized that to fix this, I need to change how I build the H5 file in the first place. Since I want to emit individual batches, the first dimension of the H5 dataset should be the batch dimension. In other words, I need to change the structure from:

[layers, batch_sz, tokens, features]
to
[batch_sz, layers, tokens, features]
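For reference, `len()` on an h5py dataset returns the size of its first axis, which is why putting samples first makes `__len__` report the sample count. If the file is rebuilt this way, the existing data can be copied into the `[batch_sz, layers, tokens, features]` layout without loading the whole array into memory. A minimal sketch, assuming the source dataset is named `embeds`; the function name `reorder_to_samples_first` and the `chunk` size are hypothetical choices for illustration:

```python
import h5py
import numpy as np


def reorder_to_samples_first(src_path, dst_path, key="embeds", chunk=256):
    """Copy [layers, samples, tokens, features] into [samples, layers, tokens, features].

    Copies `chunk` samples at a time so the full array never has to fit in RAM.
    """
    with h5py.File(src_path, "r") as src, h5py.File(dst_path, "w") as dst:
        data = src[key]
        layers, n_samples, tokens, features = data.shape
        out = dst.create_dataset(
            key,
            shape=(n_samples, layers, tokens, features),
            dtype=data.dtype,
            # one sample per HDF5 chunk (a tuning assumption, not a requirement)
            chunks=(1, layers, tokens, features),
        )
        for start in range(0, n_samples, chunk):
            stop = min(start + chunk, n_samples)
            block = data[:, start:stop, :, :]  # (layers, k, tokens, features)
            # swap the first two axes -> (k, layers, tokens, features)
            out[start:stop] = np.ascontiguousarray(block.transpose(1, 0, 2, 3))
```

With this layout, `len(f["embeds"])` returns the number of samples, and `__getitem__` can read one sample with `f["embeds"][index]`. Storing one sample per chunk means each per-index read touches a single chunk; other chunk shapes may suit other access patterns better.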