Best practices when reading a large number of files every dataloader iteration

I’m training a CNN that classifies small regions (7x7) from 40 large input layers, all of which are used for the classification. So the CNN takes a (40, 7, 7) input to be classified.
However, the 40 input layers are very large (over 10k x 10k pixels each), so I can’t read all 40 files in the dataset constructor because that would exceed the amount of RAM I have.

So instead, I’m forced to open all 40 files in every __getitem__ call of the dataset, and read only the desired 7x7 window from each input layer for that sample.

This has made my training very slow; I think it simply takes a long time to open all 40 files and read a window from each on every iteration.

I am already using multiple workers:

import numpy as np
import torch
from torch.utils.data import Dataset

torch.utils.data.DataLoader(dset, batch_size=32, num_workers=4)

class MyDataset(Dataset):
  def __init__(self):
    self.all_input_layers = ["file1", "file2", ..., "file40"]
  def __getitem__(self, idx):
    all_layer_data = []
    for file in self.all_input_layers:
      curr_window = read_file_window(file) # returns 7x7 np array
      all_layer_data.append(curr_window)

    # np.stack adds a new leading axis, so 40 arrays of 7 x 7 become 40 x 7 x 7
    # (np.concatenate along axis 0 would give 280 x 7 instead)
    data = np.stack(all_layer_data, axis=0) # 40 x 7 x 7 data that will be fed into the CNN

    return data

What are the other things I should try to speed this up?

  • Should I use torch.multiprocessing.set_sharing_strategy('file_system')?
    • I don’t have any issues with too many file descriptors being open, which is the case this ‘file_system’ strategy seems to be recommended for, so I’m not sure it would help me.
  • What else?

You could profile the method that reads the data, as well as the code that creates the windows from these tensors, and check where most of the time is spent.
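A minimal sketch of such a measurement, using only wall-clock timers. The read_file_window stub below is a stand-in for the reader from the question (its real implementation is not shown in the thread), and profile_getitem is a hypothetical helper name; swap in the real reader and file list to get meaningful numbers.

```python
import time
import numpy as np

def read_file_window(path):
    # Stand-in for the real reader (assumption): returns a 7x7 array.
    return np.zeros((7, 7), dtype=np.float32)

def profile_getitem(files, reader, n_iters=20):
    """Average wall-clock time per sample spent reading vs. stacking."""
    totals = {"read": 0.0, "stack": 0.0}
    for _ in range(n_iters):
        t0 = time.perf_counter()
        windows = [reader(f) for f in files]   # open + read 40 windows
        t1 = time.perf_counter()
        np.stack(windows, axis=0)              # assemble the (40, 7, 7) input
        t2 = time.perf_counter()
        totals["read"] += t1 - t0
        totals["stack"] += t2 - t1
    return {name: total / n_iters for name, total in totals.items()}

files = [f"file{i}" for i in range(1, 41)]
timings = profile_getitem(files, read_file_window)
print(timings)
```

Running this in the main process (without the DataLoader) keeps the numbers easy to interpret; if "read" dominates, the file access itself is the part worth optimizing.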
Once you’ve isolated it, you could try to accelerate the code (e.g. with a 3rd party library if possible) or think about changing the overall data loading (e.g. would it be possible to store the data in another format and only load the desired window instead of the whole data array).
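One way the "other format" idea could look is a single memory-mapped .npy file that holds all layers, so that each sample is one slice and no file has to be reopened per iteration. This is only a sketch under assumptions: build_stack, WindowDataset, and the corners list are hypothetical names, the demo uses tiny 64x64 layers, and in real training the class would subclass torch.utils.data.Dataset.

```python
import os
import tempfile
import numpy as np

def build_stack(path, layers, height, width, dtype=np.float32):
    """One-off conversion: copy the layers into a single on-disk .npy file."""
    out = np.lib.format.open_memmap(
        path, mode="w+", dtype=dtype, shape=(len(layers), height, width)
    )
    for i, layer in enumerate(layers):
        out[i] = layer  # in practice: load one large layer, copy it, free it
    out.flush()
    del out

class WindowDataset:
    """Map-style dataset that slices windows out of the memory-mapped stack."""
    def __init__(self, path, corners, size=7):
        self.path = path
        self.corners = corners  # (row, col) of each window's top-left corner
        self.size = size
        self._stack = None      # opened lazily, so each worker maps the file itself

    def __len__(self):
        return len(self.corners)

    def __getitem__(self, idx):
        if self._stack is None:
            self._stack = np.load(self.path, mmap_mode="r")
        r, c = self.corners[idx]
        s = self.size
        # Slicing the memmap touches only the requested region; np.array copies it.
        return np.array(self._stack[:, r:r + s, c:c + s])

# Tiny demo: 40 layers of 64x64, layer i filled with the value i.
stack_path = os.path.join(tempfile.mkdtemp(), "layers.npy")
demo_layers = [np.full((64, 64), i, dtype=np.float32) for i in range(40)]
build_stack(stack_path, demo_layers, 64, 64)
dset = WindowDataset(stack_path, corners=[(0, 0), (10, 20)])
sample = dset[1]
print(sample.shape)
```

The conversion needs enough disk space for the combined array but only one layer in RAM at a time. Whether this is actually faster depends on the original file format and the storage device, so it is worth timing before and after.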
For more general advice, have a look at this post.