Iterable Dataset Multithreaded

Hello there,

I want to make a custom dataset or dataloader; I just don’t know which is the best way to do it.
This is a self-supervised task, where you blank out part of the input and then use it as the “label”.

Let’s say I have .npy files, each one of shape (600,30000), and I want to do the following:

  • Read the file,
  • pick a row from the 600,
  • take a slice around that row, say (32,30000),
  • return y=x[row] and x where x[row]=0 (blanked out).
  • The next sample should be another (x, y) pair where you pick another row; repeat this for all rows of the file, then move on to the next file.

After thoroughly reading the docs, I thought the best way to do this was an IterableDataset, but whenever I increase the batch_size, e.g. to 56 in order to get a batch of (56, 32, 30000), loading takes too long: ~12 s for a single batch, which is an eternity of waiting for the GPU.

Code for the iterable dataset:

import os

import numpy as np
import torch


class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, data_path, N_sub, batch_size, channel_min=1700, channel_max=2300):
        self.data_path = data_path
        self.filenames = [x for x in os.listdir(data_path) if x.endswith(".npy")]
        self.channel_min = channel_min
        self.channel_max = channel_max
        self.N_sub = N_sub
        self.batch_size = batch_size

    def sliding_window(self):
        for file in self.filenames:
            data = np.load(os.path.join(self.data_path, file), mmap_mode='r')[self.channel_min:self.channel_max]
            # Normalization, this causes minimal data leakage.
            # Done once per file (not inside the row loop), so the values are not rescaled repeatedly.
            data = data / data.std()
            for row in range(data.shape[0]):
                low_index = int(row - self.N_sub / 2)
                high_index = int(row + self.N_sub / 2)
                # If the target is close to zero, pick the range [0, N_sub]; the target is not centered.
                if low_index <= 0:
                    low_index = 0
                    high_index = self.N_sub
                # If the target is close to the max channel, pick the range [max_channel - N_sub, max_channel];
                # the target is not centered here either.
                if high_index >= self.channel_max - self.channel_min:
                    high_index = self.channel_max - self.channel_min
                    low_index = high_index - self.N_sub

                # Copy the target row before blanking it out; the model will predict this.
                y_ = data[row].copy()

                # Copy the window and zero out the target channel inside the copy,
                # so the shared array is left untouched for the following rows.
                x_ = data[low_index:high_index].copy()
                x_[row - low_index] = 0
                # Keep only frequencies from f_min to f_max.
                # x_ = taper_filter(x_, self.f_min, self.f_max, self.sampleRate)
                # y_ = taper_filter(y_, self.f_min, self.f_max, self.sampleRate)
                x_ = torch.tensor(x_.astype(np.float32))
                y_ = torch.tensor(y_.astype(np.float32))
                yield x_, y_

    def __iter__(self):
        return self.sliding_window()

    def __len__(self):
        # One sample is yielded per row of the sliced array, for every file.
        return len(self.filenames) * (self.channel_max - self.channel_min)

I thought of parallelizing this, but I can’t find a related example, so:

  1. Any code/examples/advice you see fit is welcome,
  2. Is this the way to go? Or am I overcomplicating things?

Thanks a lot in advance,
Have a great day!

You can try multi-process data loading; you may find example 2 here helpful.
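
A rough sketch of what that looks like, assuming the MyIterableDataset from above (the num_workers value and data path are hypothetical): passing num_workers to the DataLoader spawns that many worker processes, each running the dataset’s __iter__, and the yielded (32, 30000) windows are collated into (56, 32, 30000) batches. Note that without the per-worker sharding described in the next reply, every worker iterates over all files and yields duplicate samples.

from torch.utils.data import DataLoader

# Hypothetical values; adjust N_sub, batch_size and num_workers to your setup.
dataset = MyIterableDataset(data_path="/path/to/npy/files/", N_sub=32, batch_size=56)

loader = DataLoader(
    dataset,
    batch_size=56,    # collate 56 (32, 30000) windows into one (56, 32, 30000) batch
    num_workers=4,    # number of worker processes reading files in parallel
    pin_memory=True,  # speeds up host-to-GPU transfer
)

for x, y in loader:
    # x: (56, 32, 30000), y: (56, 30000)
    pass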

What you can do is provide multiple workers to the DataLoader. And, to prevent duplicate data across processes, you can give each Dataset instance within each worker its own part of self.filenames, based on the worker_id (see the sketch below).
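
A minimal sketch of that idea, reusing the MyIterableDataset defined above: torch.utils.data.get_worker_info() returns None in the main process and a WorkerInfo object (with id and num_workers) inside each worker, so __iter__ can shard self.filenames before iterating.

import torch


class MyIterableDataset(torch.utils.data.IterableDataset):
    # ... __init__ and sliding_window exactly as above ...

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            # Running inside a DataLoader worker: keep every num_workers-th file,
            # starting at this worker's id, so no file is processed twice.
            self.filenames = self.filenames[worker_info.id::worker_info.num_workers]
        return self.sliding_window()

Each worker process gets its own copy of the dataset object, so mutating self.filenames inside one worker does not affect the others.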

Thanks for the reply. I’m still trying to get this working on my problem.
One question though: if I move to multi-GPU in the future, will something like that scale?

Yes, you can use the torch.distributed package. You can find more details on this page.

We recommend following the instructions in “3. Use single-machine multi-GPU DistributedDataParallel” on that page.
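
For reference, a minimal single-machine multi-GPU sketch with DistributedDataParallel, assuming a hypothetical MyModel and the loader built as in the earlier sketch, launched with e.g. torchrun --nproc_per_node=NUM_GPUS train.py. The MSE loss is only an assumption about how the blanked-out row would be predicted.

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    # torchrun sets RANK, LOCAL_RANK and WORLD_SIZE for each process.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = MyModel().cuda(local_rank)           # hypothetical model
    model = DDP(model, device_ids=[local_rank])

    # Each process should also see a different shard of the files,
    # e.g. filenames[dist.get_rank()::dist.get_world_size()],
    # in addition to the per-worker sharding shown earlier.

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.MSELoss()

    for x, y in loader:                          # loader built as in the earlier sketch
        x, y = x.cuda(local_rank), y.cuda(local_rank)
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()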