Use index tricks to avoid second memmap file

Hello everyone,
I have a question about the way PyTorch handles indices.
My memmap file is really huge (~70 GB). I want to train my CNN on a certain subset of this memmap that contains around half of the data points, and these points are not evenly distributed through the data set. I don’t want to create a second memmap file because it would take up a lot of disk space (especially since I already have other memmap files of similar size).
So my idea is to do something like this:

  1. Get the length of my memmap subset (e.g. 10k points) and store the corresponding indices (1, 4, 7, 9, etc.).
  2. Use the subset length (10k points) as the length of my dataset class.
  3. Whenever the DataLoader requests an index, I look up the corresponding memmap index in the stored subset indices.
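The three steps above can be sketched as a small map-style dataset. This is a minimal sketch, assuming the memmap was written as a raw binary file with a known dtype and shape (np.memmap needs both to reopen it); `IndexedMemmapDataset`, the temp file and the toy shape (10, 3) are hypothetical placeholders. The file is opened lazily and dropped when the object is pickled, so DataLoader workers started with the spawn method reopen the file instead of copying the array.

```python
import os
import tempfile

import numpy as np


class IndexedMemmapDataset:
    """Hypothetical map-style dataset exposing only a subset of a memmap.

    It implements __len__ and __getitem__, which is all a map-style
    DataLoader needs; subclassing torch.utils.data.Dataset also works.
    """

    def __init__(self, path, dtype, shape, subset_indices):
        self.path, self.dtype, self.shape = path, dtype, shape
        # Step 1: store the memmap indices of the wanted samples.
        self.indices = np.asarray(subset_indices, dtype=np.int64)
        self._data = None  # opened on first access (e.g. inside each worker)

    def __len__(self):
        # Step 2: report the subset length, not the full memmap length.
        return len(self.indices)

    def __getitem__(self, i):
        if self._data is None:
            # Read-only memmap: nothing is loaded until a row is accessed.
            self._data = np.memmap(self.path, dtype=self.dtype, mode="r", shape=self.shape)
        # Step 3: translate the loader's index i into the real memmap row.
        real_idx = self.indices[i]
        # np.array(...) copies only this single sample out of the memmap.
        return np.array(self._data[real_idx])

    def __getstate__(self):
        # Never pickle the open memmap (that would serialize its contents);
        # a worker process reopens the file lazily instead.
        state = self.__dict__.copy()
        state["_data"] = None
        return state


# Usage with a tiny stand-in file: row k holds [3k, 3k+1, 3k+2].
tmp = os.path.join(tempfile.mkdtemp(), "data.dat")
full = np.memmap(tmp, dtype=np.float32, mode="w+", shape=(10, 3))
full[:] = np.arange(30, dtype=np.float32).reshape(10, 3)
full.flush()
del full

ds = IndexedMemmapDataset(tmp, np.float32, (10, 3), [1, 4, 7, 9])
print(len(ds))  # 4
print(ds[1])    # memmap row 4 -> [12. 13. 14.]
```

Only the index array is kept in memory (10k int64 indices take about 80 KB), and no second file is written.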

Does something like this work or does it require a different workaround?

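Your workflow describes what torch.utils.data.Subset does and sounds reasonable: Subset wraps a dataset together with a list of indices, reports len(indices) as its length and maps each requested index i to dataset[indices[i]], without copying the underlying data.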
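With torch.utils.data.Subset you don't have to write the index lookup yourself: wrap a dataset over the full memmap and pass it the stored indices. A minimal sketch, assuming the same raw-binary memmap layout with known dtype and shape; `MemmapDataset`, the temp file and the toy shape (10, 3) are hypothetical stand-ins for the real 70 GB file.

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Subset


class MemmapDataset(Dataset):
    """Hypothetical dataset over the *full* memmap, one sample per row."""

    def __init__(self, path, dtype, shape):
        self.path, self.dtype, self.shape = path, dtype, shape
        self._data = None  # opened lazily, e.g. once per DataLoader worker

    def __len__(self):
        return self.shape[0]

    def __getitem__(self, idx):
        if self._data is None:
            self._data = np.memmap(self.path, dtype=self.dtype, mode="r", shape=self.shape)
        # Copy one row out of the memmap and hand it to PyTorch.
        return torch.from_numpy(np.array(self._data[idx]))


# Tiny stand-in file: row k holds [3k, 3k+1, 3k+2].
path = os.path.join(tempfile.mkdtemp(), "data.dat")
full = np.memmap(path, dtype=np.float32, mode="w+", shape=(10, 3))
full[:] = np.arange(30, dtype=np.float32).reshape(10, 3)
full.flush()
del full

full_ds = MemmapDataset(path, np.float32, (10, 3))
subset = Subset(full_ds, [1, 4, 7, 9])  # subset[i] -> full_ds[[1, 4, 7, 9][i]]
loader = DataLoader(subset, batch_size=2, shuffle=True)

for batch in loader:
    print(batch.shape)  # torch.Size([2, 3])
```

Shuffling in the DataLoader then permutes the positions 0..len(subset)-1, and Subset translates each position into the stored memmap index.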