Note that train_dataset is MNIST in my case. These two lines occur X times in a for loop, and they slow the loop down by approximately 3.5 seconds per iteration!
I’m not using the DataLoader class, so I’m trying to see if there’s a more efficient way of reducing the total training dataset at each iteration.
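Roughly, the pattern is of this form (a simplified illustration, not my exact code; `keep_idx` stands in for whatever indexing the loop actually computes):

```python
import torch
from torchvision import datasets

# Hypothetical illustration of the slow pattern: the MNIST tensors are
# sliced and re-assigned onto the dataset at every iteration.
# (Older torchvision exposes .train_data/.train_labels; newer versions
# call them .data/.targets.)
train_dataset = datasets.MNIST(root="data", train=True, download=True)

for step in range(10):
    keep_idx = torch.arange(len(train_dataset.train_labels) - 1)  # placeholder indexing
    train_dataset.train_data = train_dataset.train_data[keep_idx]
    train_dataset.train_labels = train_dataset.train_labels[keep_idx]
```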
It looks like you are re-assigning the sliced data back onto your dataset. Could you slice it and store the result in a temporary variable instead, or do you really need the re-assignment? I suspect the re-assignment might be what is slowing down your code.
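Something along these lines, so the dataset itself stays untouched (a sketch reusing the names from the snippet above; `keep_idx` is again a placeholder for your indexing):

```python
# Slice into local temporaries instead of re-assigning onto the dataset;
# train_dataset keeps its full data.
data_subset = train_dataset.train_data[keep_idx]
labels_subset = train_dataset.train_labels[keep_idx]
```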
If you want to use a data loader (which extracts batches efficiently and can use multiprocessing), you can call the get_batch function defined in the code below on train_dataset.
Note that this assumes your dataset class returns a train_data sample and the corresponding train_labels entry from its __getitem__(self, index) method.
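Something along these lines (a minimal sketch; the dataset class name, batch size, and worker count are illustrative placeholders, not your actual code):

```python
import torch
from torch.utils.data import Dataset, DataLoader


class TensorPairDataset(Dataset):
    """Hypothetical dataset wrapping pre-loaded tensors; __getitem__ returns
    one (train_data, train_labels) pair, as assumed above."""

    def __init__(self, train_data, train_labels):
        self.train_data = train_data
        self.train_labels = train_labels

    def __getitem__(self, index):
        return self.train_data[index], self.train_labels[index]

    def __len__(self):
        return len(self.train_labels)


def get_batch(train_dataset, batch_size=64, num_workers=2):
    """Wrap the dataset in a DataLoader so batching, shuffling, and
    multi-process loading are handled for you."""
    return DataLoader(train_dataset,
                      batch_size=batch_size,
                      shuffle=True,
                      num_workers=num_workers)


if __name__ == "__main__":
    # Example usage with dummy MNIST-shaped tensors.
    train_dataset = TensorPairDataset(torch.randn(1000, 1, 28, 28),
                                      torch.randint(0, 10, (1000,)))
    for data, labels in get_batch(train_dataset, batch_size=32):
        pass  # forward/backward pass would go here
```

This way the full dataset is loaded once, and each iteration only pulls out the batch it needs instead of rebuilding (or re-assigning) the whole dataset.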