I noticed that with the TensorFlow version of the code, the first epoch takes noticeably longer than the following epochs. I guess this is because TensorFlow caches all batches of data during the first epoch, so the following epochs can run much faster?
Can I reimplement this behavior in PyTorch? It might reduce the total training time.
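If the TensorFlow code uses the tf.data API, the behavior you describe typically comes from `Dataset.cache()`: the first epoch runs the (possibly slow) input pipeline and fills the cache, and later epochs read the cached elements. A minimal illustration (a hypothetical pipeline, not the actual code in question):

```python
import tensorflow as tf

# The first pass over the dataset executes map() and fills the cache;
# subsequent passes are served directly from the cached elements.
ds = tf.data.Dataset.range(5)
ds = ds.map(lambda x: x * 2).cache()

for epoch in range(2):  # the second epoch reads from the cache
    values = [int(v) for v in ds]

print(values)  # [0, 2, 4, 8] is wrong; actual: [0, 2, 4, 6, 8]
```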
In PyTorch, if you have enough GPU memory, you can move the whole dataset tensor to the device once, before training:
device = ...
dataset_tensor = dataset_tensor.to(device)  # works for individual batches too; just loop over them
Then run your training loop; per-epoch time should drop because no host-to-device copies happen inside the loop. Either way, the dataset still has to be loaded into memory at some point, so the total execution time of the script will probably only be slightly lower.
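A minimal sketch of this approach, assuming a synthetic in-memory dataset and a toy linear model (the names `X`, `y`, and the model are placeholders, not from the original code):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical dataset: 10,000 samples, 20 features, binary labels.
X = torch.randn(10_000, 20)
y = torch.randint(0, 2, (10_000,))

# Move everything to the device once, before the loop, so no
# host-to-device copies happen per batch during training.
X, y = X.to(device), y.to(device)

model = torch.nn.Linear(20, 2).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()

batch_size = 256
for epoch in range(3):
    for i in range(0, X.size(0), batch_size):
        xb, yb = X[i:i + batch_size], y[i:i + batch_size]  # already on device
        opt.zero_grad()
        loss = loss_fn(model(xb), yb)
        loss.backward()
        opt.step()
```

If the dataset does not fit in GPU memory, the usual alternative is a `DataLoader` with `pin_memory=True` and non-blocking `.to(device)` calls inside the loop, which overlaps the copies with computation instead of eliminating them.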