Just a suggestion: don't use the vanilla DataLoader for this. If you can fit all the data into RAM, just turn each set of data into a single tensor, with shape:
self.dataset.size() == torch.Size([num_samples, ...])
where the first dimension is the sample index.
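For instance, assuming your samples arrive as a list of equally-shaped tensors (a hypothetical setup), torch.stack produces exactly that layout:

import torch

# hypothetical data: 1,000 samples, each a 3x32x32 image
samples = [torch.randn(3, 32, 32) for _ in range(1000)]
dataset = torch.stack(samples)  # dataset.size() == torch.Size([1000, 3, 32, 32])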
Then create an indexer, either in NumPy or PyTorch. Here is an example in NumPy:
import numpy as np

rng = np.random.default_rng()
...
self.data_index = np.arange(self.dataset.size()[0])
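If you would rather stay entirely in PyTorch, torch.randperm covers the same job; a minimal equivalent sketch:

import torch

self.data_index = torch.arange(self.dataset.size()[0])  # ordered index
# reshuffling is then a single call:
self.data_index = torch.randperm(self.dataset.size()[0])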
Then define a shuffler method you can call whenever you want to reshuffle the index:
def shuffler(self):
    # shuffle the index array in place; the data tensor itself never moves
    rng.shuffle(self.data_index)
We also need to define batch_size, batch_idx, and max_idx in __init__:
self.batch_size = batch_size
# number of batches per epoch; the last one absorbs any remainder
self.max_idx = self.dataset.size()[0] // self.batch_size
self.batch_idx = 0
Now we can define a get_batch method:
def get_batch(self):
    start = self.batch_idx * self.batch_size
    if self.batch_idx == self.max_idx - 1:
        # last batch: slice to the end so the remainder isn't dropped,
        # then wrap around for the next epoch
        batch = self.dataset[self.data_index[start:], ...]
        self.batch_idx = 0
    else:
        batch = self.dataset[self.data_index[start:start + self.batch_size], ...]
        self.batch_idx += 1
    return batch
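Putting the pieces together, here is a minimal self-contained sketch; the class name FastLoader and the random data are placeholder assumptions, not anything from a library:

import numpy as np
import torch

rng = np.random.default_rng()

class FastLoader:
    def __init__(self, dataset, batch_size):
        self.dataset = dataset  # tensor of shape (num_samples, ...)
        self.batch_size = batch_size
        self.max_idx = self.dataset.size()[0] // self.batch_size
        self.batch_idx = 0
        self.data_index = np.arange(self.dataset.size()[0])

    def shuffler(self):
        rng.shuffle(self.data_index)

    def get_batch(self):
        start = self.batch_idx * self.batch_size
        if self.batch_idx == self.max_idx - 1:
            batch = self.dataset[self.data_index[start:], ...]
            self.batch_idx = 0
        else:
            batch = self.dataset[self.data_index[start:start + self.batch_size], ...]
            self.batch_idx += 1
        return batch

# usage with fake data: one shuffled epoch
loader = FastLoader(torch.randn(1000, 3, 32, 32), batch_size=64)
loader.shuffler()
for _ in range(loader.max_idx):
    batch = loader.get_batch()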
That should get you started, at least.
For an additional speed boost, you can move the whole dataset onto a spare GPU and slice batches directly from device memory.
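A sketch of that idea, assuming a second GPU is visible as cuda:1 (adjust the device string to your setup):

import torch

device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cpu')
dataset = dataset.to(device)  # every batch sliced from it now lives on that device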