For instance, your `get_batch` method could look like this:
```python
def get_batch(self):
    start = self.batch_idx * self.batch_size
    if self.batch_idx == self.max_idx - 1:
        # last batch: take all remaining indices, then wrap around to the start
        batch = self.dataset[self.indexer[start:], ...]
        self.batch_idx = 0
    else:
        batch = self.dataset[self.indexer[start:start + self.batch_size], ...]
        self.batch_idx += 1
    # move the batch to the GPU first so the transforms in process_batch run there;
    # `device` must be defined elsewhere, e.g. torch.device("cuda")
    batch = self.process_batch(batch.to(device))
    return batch
```
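The only subtle part of `get_batch` is the index arithmetic, so here is a stdlib-only sketch of just that logic. It assumes (the answer does not say so) that `max_idx` is the number of batches per epoch, `ceil(N / batch_size)`, and that `indexer` is a shuffled permutation of `range(N)`. `IndexBatcher` and `get_indices` are hypothetical names used only for illustration.

```python
import math
import random


class IndexBatcher:
    """Hypothetical stand-in for the loader that returns index batches only.

    Assumes max_idx = ceil(n_samples / batch_size) and that indexer is a
    shuffled permutation of range(n_samples).
    """

    def __init__(self, n_samples, batch_size, seed=0):
        self.batch_size = batch_size
        self.max_idx = math.ceil(n_samples / batch_size)  # batches per epoch
        self.indexer = list(range(n_samples))
        random.Random(seed).shuffle(self.indexer)          # shuffled sample order
        self.batch_idx = 0

    def get_indices(self):
        start = self.batch_idx * self.batch_size
        if self.batch_idx == self.max_idx - 1:
            # last batch: whatever is left (may be smaller), then wrap around
            idx = self.indexer[start:]
            self.batch_idx = 0
        else:
            idx = self.indexer[start:start + self.batch_size]
            self.batch_idx += 1
        return idx


batcher = IndexBatcher(n_samples=10, batch_size=4)
sizes = [len(batcher.get_indices()) for _ in range(batcher.max_idx)]
print(sizes)  # [4, 4, 2]
```

With 10 samples and a batch size of 4 this yields batches of 4, 4 and 2 indices, every sample is visited exactly once per epoch, and `batch_idx` is back at 0 afterwards.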
You can then define your transforms in a separate `process_batch` method, which `get_batch` calls:
```python
def process_batch(self, data):
    data = data / 255.  # normalize pixel values from [0, 255] to [0, 1]
    # ... other transforms
    return data
```
Of course, both of these are methods of your data-loader class (hence the `self` argument).
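Putting both methods into one class might look roughly like the following. This is a sketch under stated assumptions, not the answer's actual class: the class name `GPUBatchLoader`, the `device` attribute, the in-memory uint8 dataset tensor of shape `(N, C, H, W)`, and the reshuffle at the end of each epoch are all assumptions added for illustration.

```python
import math

import torch


class GPUBatchLoader:
    """Hypothetical loader class holding get_batch and process_batch.

    Assumes the whole dataset fits in memory as one uint8 tensor of
    shape (N, C, H, W); all names here are illustrative.
    """

    def __init__(self, dataset, batch_size, device):
        self.dataset = dataset
        self.batch_size = batch_size
        self.device = device
        self.max_idx = math.ceil(len(dataset) / batch_size)  # batches per epoch
        self.indexer = torch.randperm(len(dataset))           # shuffled sample order
        self.batch_idx = 0

    def get_batch(self):
        start = self.batch_idx * self.batch_size
        if self.batch_idx == self.max_idx - 1:
            # last batch: all remaining samples, then start a new epoch
            batch = self.dataset[self.indexer[start:], ...]
            self.batch_idx = 0
            self.indexer = torch.randperm(len(self.dataset))  # assumed: reshuffle each epoch
        else:
            batch = self.dataset[self.indexer[start:start + self.batch_size], ...]
            self.batch_idx += 1
        # move to the target device first so the transforms run there
        return self.process_batch(batch.to(self.device))

    def process_batch(self, data):
        data = data / 255.  # uint8 [0, 255] -> float32 [0, 1]
        return data


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
images = torch.randint(0, 256, (10, 3, 8, 8), dtype=torch.uint8)
loader = GPUBatchLoader(images, batch_size=4, device=device)
for _ in range(loader.max_idx):
    batch = loader.get_batch()
    print(batch.shape, batch.dtype, batch.device.type)
```

Moving the batch to the device before calling `process_batch` means the division (and any other tensor transforms) runs on the GPU when one is available; on a CPU-only machine the same code still runs, just without the speedup.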