I had (and still have) the same problem as you, using 2048x2048 images on Google Colab. Profiling the training process, I noticed that the bottleneck is actually fetching the data from Google Drive.
Here are the workarounds I have adopted so far:
- set the device globally (note: the result of `torch.device(...)` has to be assigned, otherwise it is discarded):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```
- use this heuristic to set `num_workers`:

```python
import os

NUM_WORKERS = 4 * os.cpu_count()
```
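As a sketch of how these settings plug into a `DataLoader` (the toy dataset, batch size, and the `pin_memory`/`persistent_workers` flags here are illustrative, not from my actual setup):

```python
import os
import torch
from torch.utils.data import DataLoader, TensorDataset

NUM_WORKERS = 4 * os.cpu_count()

# Toy tensors standing in for the real image dataset.
ds = TensorDataset(torch.rand(32, 3, 8, 8), torch.arange(32))

# pin_memory speeds up host-to-GPU copies when a GPU is present;
# persistent_workers keeps worker processes alive between epochs
# instead of re-forking them each time.
loader = DataLoader(
    ds,
    batch_size=8,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
    persistent_workers=NUM_WORKERS > 0,
)
```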
I believe, however, that to get maximum performance you need to preload all of the data into memory beforehand, but I couldn't find a ready-made solution for this online.
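In case it helps, this is a minimal sketch of that preloading idea, assuming the decoded images fit in RAM (the random tensors stand in for real images):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class InMemoryDataset(Dataset):
    """Holds every sample in RAM, so __getitem__ never touches the disk."""

    def __init__(self, samples, labels):
        # Stack everything into contiguous tensors kept in memory.
        self.samples = torch.stack(
            [torch.as_tensor(s, dtype=torch.float32) for s in samples]
        )
        self.labels = torch.as_tensor(labels)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.samples[idx], self.labels[idx]

# Hypothetical toy data standing in for decoded images.
data = [torch.rand(3, 64, 64) for _ in range(8)]
labels = list(range(8))
loader = DataLoader(InMemoryDataset(data, labels), batch_size=4, shuffle=True)
for x, y in loader:
    print(x.shape)  # torch.Size([4, 3, 64, 64])
```

The one-time cost of filling RAM is paid at startup instead of on every epoch, which is exactly what you want when the storage backend (GDrive here) is slow.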
To speed up loading further, the images can be packed into a single HDF5 file, with a `Dataset` wrapper around it so a regular `DataLoader` can consume it.
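A rough sketch of that HDF5 approach using `h5py` (the file name, dataset keys, and shapes are made up for illustration; adapt them to your data):

```python
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset

# One-time packing step: write all images into a single HDF5 file
# (random arrays stand in for real images here).
with h5py.File("images.h5", "w") as f:
    f.create_dataset("images", data=np.random.rand(8, 3, 64, 64).astype("float32"))
    f.create_dataset("labels", data=np.arange(8))

class HDF5Dataset(Dataset):
    """Reads samples from one HDF5 file instead of many small files."""

    def __init__(self, path):
        self.path = path
        self.file = None  # opened lazily so it is safe with num_workers > 0
        with h5py.File(path, "r") as f:
            self.length = len(f["labels"])

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Each worker process opens its own file handle on first access.
        if self.file is None:
            self.file = h5py.File(self.path, "r")
        x = torch.from_numpy(self.file["images"][idx])
        y = int(self.file["labels"][idx])
        return x, y
```

Reading one big file sequentially is much friendlier to GDrive than thousands of tiny per-image reads, and the lazy file-open keeps the handle from being shared across worker processes.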
Here is some useful information: