Increase speed of loading RAW images


I’m trying to reproduce the results of the paper Learning to See in the Dark (this repo) using PyTorch. I have referred to the implementations in this repo and this repo, but I ran into some issues.
As far as I know, these repos load RAW images in two ways:

  • Use PyTorch’s DataLoader and read each RAW image with the rawpy package in a custom __getitem__. This way I can use all the image pairs for training and validation, but it is extremely slow because the files are quite large: one epoch took me about 4200 s to finish. (cydonia999’s implementation)
  • Load all images into a dictionary/array variable. The images are kept in RAM, so this approach requires a lot of memory, but they only need to be loaded once. After that, training and validation run very quickly (about 30 to 45 s per epoch on a V100 GPU).
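A minimal sketch of the first approach (decoding in __getitem__), assuming the Sony black level 512 and white level 16383 used in the original SID code. The names `pack_bayer`, `RawOnTheFly` and the `pairs` list are hypothetical, the exposure-ratio amplification applied in SID is left out, and this is not cydonia999’s actual code. Subclassing torch.utils.data.Dataset is omitted so the sketch runs without PyTorch; any object with `__len__` and `__getitem__` like this works as a map-style dataset for torch.utils.data.DataLoader. Every call decodes two full RAW files, which is where the per-epoch time goes.

```python
import numpy as np

# Assumed Sony sensor levels (black level 512, 14-bit white level 16383),
# as in the original SID code; other cameras need different values.
BLACK_LEVEL = 512
WHITE_LEVEL = 16383


def pack_bayer(bayer):
    """Normalize an (H, W) Bayer mosaic to [0, 1] and pack it into (H/2, W/2, 4)."""
    im = np.maximum(bayer.astype(np.float32) - BLACK_LEVEL, 0) / (WHITE_LEVEL - BLACK_LEVEL)
    # Take one pixel from each position of the 2x2 Bayer pattern.
    return np.stack(
        (im[0::2, 0::2], im[0::2, 1::2], im[1::2, 1::2], im[1::2, 0::2]),
        axis=-1,
    )


class RawOnTheFly:
    """Hypothetical map-style dataset that decodes one RAW pair per __getitem__."""

    def __init__(self, pairs):
        self.pairs = pairs  # list of (input_raw_path, target_raw_path)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        import rawpy  # imported lazily; only needed when files are actually read

        in_path, gt_path = self.pairs[idx]
        with rawpy.imread(in_path) as raw:
            x = pack_bayer(raw.raw_image_visible)  # float32 copy, safe after close
        with rawpy.imread(gt_path) as raw:
            y = raw.postprocess(use_camera_wb=True, no_auto_bright=True, output_bps=16)
        return x, y
```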

I prefer the second method, because with the first one 2000 epochs at more than an hour each would take far too long. But with the limited RAM I have on Colab Pro (25 GB), I cannot load all the images. Right now I load half of the images and reload a new random half every 100 epochs. Is this the most suitable approach for me? Or is there any way to speed up loading with the DataLoader?
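The half-and-reload scheme described above can be sketched as a dataset that keeps a random subset in RAM and redraws it on request. `RandomSubsetCache` and `load_pair` are hypothetical names: `load_pair` stands for any per-index loader (for example rawpy decoding), and the 50% default fraction is an assumption.

```python
import random


class RandomSubsetCache:
    """Keeps a random fraction of the dataset in RAM and redraws it on refresh()."""

    def __init__(self, num_items, load_pair, fraction=0.5, seed=0):
        self.num_items = num_items
        self.load_pair = load_pair  # hypothetical callable: index -> (input, target)
        self.subset_size = max(1, int(num_items * fraction))
        self.rng = random.Random(seed)
        self.indices = []
        self.cache = []
        self.refresh()

    def refresh(self):
        # Release the old subset first so peak memory stays near one subset.
        self.cache = []
        self.indices = self.rng.sample(range(self.num_items), self.subset_size)
        self.cache = [self.load_pair(i) for i in self.indices]

    def __len__(self):
        return self.subset_size

    def __getitem__(self, idx):
        return self.cache[idx]


if __name__ == "__main__":
    data = RandomSubsetCache(num_items=10, load_pair=lambda i: (i, i), fraction=0.5)
    for epoch in range(300):
        if epoch > 0 and epoch % 100 == 0:
            data.refresh()  # draw a new random half every 100 epochs
        # ... train on DataLoader(data, ...) for this epoch ...
```

With `num_workers=0` the DataLoader always sees the current subset; with worker processes, `persistent_workers` should stay `False` so that the workers created for each epoch pick up the refreshed cache.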

Best regards,

Some ideas here:
What are the image sizes? Can you resample them beforehand?
Have you considered saving the images directly as NumPy arrays?
Have you tried using a NumPy memory map?
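A sketch of the suggestion to save NumPy arrays directly: decode every RAW file once, store the result as an uncompressed `.npy` file, and read those files in the dataset, so each epoch pays only for disk reads instead of RAW decoding. `convert_once`, `NpyDataset` and the `decode` callable are hypothetical names; resampling, if acceptable for the experiment, could also happen inside `decode`.

```python
import os

import numpy as np


def convert_once(items, decode, out_dir):
    """Decode each item once and save it as an uncompressed .npy file.

    `decode` is a hypothetical callable (e.g. a rawpy-based reader) that turns
    one source item into a NumPy array.
    """
    os.makedirs(out_dir, exist_ok=True)
    paths = []
    for i, item in enumerate(items):
        path = os.path.join(out_dir, f"{i:05d}.npy")
        if not os.path.exists(path):  # skip files converted in a previous run
            np.save(path, decode(item))
        paths.append(path)
    return paths


class NpyDataset:
    """Map-style dataset that reads pre-decoded .npy files (no RAW decoding)."""

    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        return np.load(self.paths[idx])
```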

Each input image has shape (1424, 2128, 4) with dtype np.uint16.
Currently I hold the images in NumPy arrays allocated like this:

import numpy as np
X = np.zeros((m_x, 1424, 2128, 4), dtype=np.uint16)  # m_x packed RAW inputs
Y = np.zeros((m_y, 1424, 2128, 3), dtype=np.uint16)  # m_y RGB targets

Then I fill each slot with the corresponding image’s values.
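A quick check of how much RAM the preallocated arrays need: the sizes follow directly from the shapes and dtype above (np.uint16 is 2 bytes per value). The helper name and the 25 GiB budget (reading Colab Pro’s 25 GB as GiB) are assumptions.

```python
import numpy as np


def array_bytes(count, shape, dtype):
    """Bytes needed by `count` items of the given shape and dtype."""
    return count * int(np.prod(shape)) * np.dtype(dtype).itemsize


INPUT_SHAPE = (1424, 2128, 4)   # packed RAW input, from the post
TARGET_SHAPE = (1424, 2128, 3)  # RGB target, from the post

per_pair = array_bytes(1, INPUT_SHAPE, np.uint16) + array_bytes(1, TARGET_SHAPE, np.uint16)
budget = 25 * 1024**3  # assumed 25 GiB, before Python, PyTorch and the model take their share
print(f"{per_pair / 1e6:.1f} MB per pair, at most {budget // per_pair} pairs")
```

This prints 42.4 MB per pair and at most 632 pairs; the practical limit is lower because the rest of the training process also needs RAM, and converting to float32 doubles the size of every image.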

What is the benefit of using a memory map?

Thank you for replying!

Hmm, you can create one big NumPy array with all the images stacked together and then read only the images/fragments that you need. You save the time required to instantiate the array and so on.
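A sketch of the stacked-array idea with a memory map: all images are written once into a single `.npy` file with `np.lib.format.open_memmap`, and the dataset opens it with `np.load(..., mmap_mode="r")` and slices out a random patch, so only the parts of the file covering that patch are read from disk rather than whole images. `build_stack`, `StackedCrops`, the `load_image` callable and the 512-pixel patch size are assumptions for illustration.

```python
import numpy as np


def build_stack(path, load_image, count, shape, dtype):
    """Write `count` equally shaped images into one .npy file, one at a time."""
    stack = np.lib.format.open_memmap(
        path, mode="w+", dtype=dtype, shape=(count,) + tuple(shape)
    )
    for i in range(count):
        stack[i] = load_image(i)  # only one decoded image is held in RAM
    stack.flush()
    del stack


class StackedCrops:
    """Map-style dataset returning random patches from the memory-mapped stack."""

    def __init__(self, path, patch=512):
        self.path = path
        self.patch = patch
        self.length = np.load(path, mmap_mode="r").shape[0]
        self.stack = None  # opened lazily so each DataLoader worker maps the file itself
        self.rng = None

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if self.stack is None:
            self.stack = np.load(self.path, mmap_mode="r")
            self.rng = np.random.default_rng()  # fresh OS entropy in each process
        h, w = self.stack.shape[1:3]
        top = self.rng.integers(0, h - self.patch + 1)
        left = self.rng.integers(0, w - self.patch + 1)
        # Slicing the memmap touches only the rows of this crop, not the whole image.
        crop = self.stack[idx, top:top + self.patch, left:left + self.patch]
        return np.ascontiguousarray(crop)  # copy the patch into regular memory
```

The operating system’s page cache then keeps frequently read parts of the file in RAM without the Python process holding the whole dataset.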

Another option is converting to uint8, which halves the amount of data to read.
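A sketch of the uint8 option: linearly rescaling 14-bit values (white level 16383, an assumption matching the Sony files used in SID) into the range 0 to 255 halves the bytes stored and read. The function names are hypothetical.

```python
import numpy as np


def to_uint8(raw, white_level=16383):
    """Linearly rescale [0, white_level] to [0, 255]; half the bytes of uint16."""
    scaled = raw.astype(np.float32) * (255.0 / white_level)
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)


def from_uint8(q, white_level=16383):
    """Map uint8 values back to the approximate original raw scale."""
    return q.astype(np.float32) * (white_level / 255.0)
```

Each uint8 step then covers about 64 raw levels, so much of the detail in very dark inputs, which the network later amplifies, is lost; whether that is acceptable would need to be checked for this task.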

Apart from that, there is not much you can do.
You might explore NVIDIA’s Data Loading Library (DALI),
which has a faster data loader than PyTorch’s, but it is not as simple to use.

Thank you for the suggestion!