A use case for loading the entire dataset into RAM

For instance, in get_batch you could do:

def get_batch(self):
    start = self.batch_idx * self.batch_size
    if self.batch_idx == self.max_idx - 1:
        # last batch: take whatever is left, then wrap around for the next epoch
        batch = self.dataset[self.indexer[start:], ...]
        self.batch_idx = 0
    else:
        batch = self.dataset[self.indexer[start:start + self.batch_size], ...]
        self.batch_idx += 1

    batch = self.process_batch(batch.to(device))  # send it to the GPU for faster processing
    return batch

And then you can define your transforms inside process_batch:

def process_batch(self, data):
    data = data / 255.  # normalize images to [0, 1]
    # other transforms go here
    return data

Of course, both of the above would be methods of the same class, which also holds the dataset tensor, the batch size, and the shuffled index order.
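Putting it together, here is a minimal self-contained sketch of such a class. The class name InMemoryLoader is hypothetical; the attribute names (dataset, indexer, batch_size, batch_idx, max_idx) follow the snippets above, and I am assuming indexer is a shuffled permutation of the sample indices that gets reshuffled after each pass:

import torch

class InMemoryLoader:
    """Sketch of an in-RAM batch loader (hypothetical name).

    Holds the full dataset as one tensor and serves shuffled batches,
    following the get_batch/process_batch pattern described above.
    """

    def __init__(self, dataset, batch_size, device="cpu"):
        self.dataset = dataset                 # full tensor, e.g. (N, C, H, W) uint8
        self.batch_size = batch_size
        self.device = device
        self.indexer = torch.randperm(len(dataset))  # shuffled sample order
        # number of batches per epoch, counting a possibly smaller final batch
        self.max_idx = (len(dataset) + batch_size - 1) // batch_size
        self.batch_idx = 0

    def get_batch(self):
        start = self.batch_idx * self.batch_size
        if self.batch_idx == self.max_idx - 1:
            # last batch: take the remainder, then reshuffle for the next epoch
            batch = self.dataset[self.indexer[start:], ...]
            self.batch_idx = 0
            self.indexer = torch.randperm(len(self.dataset))
        else:
            batch = self.dataset[self.indexer[start:start + self.batch_size], ...]
            self.batch_idx += 1
        return self.process_batch(batch.to(self.device))

    def process_batch(self, data):
        return data.float() / 255.0  # normalize images to [0, 1]

With a dataset of 10 samples and batch_size=4, the loader yields batches of 4, 4, and 2 samples, then starts a new shuffled epoch.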
