enumerate(DataLoader)#_MultiProcessingDataLoaderIter Extremely Slow

Hi,

I have developed an audio-visual facial reenactment solution and have tested my code with several normal-sized datasets, where it works perfectly. However, I am experiencing a major speed issue when I use a high-resolution dataset. Specifically, I am now trying to use a large (2.6 TB) high-resolution audio-visual dataset containing videos with a resolution of 192x192. Everything works well with the low-resolution datasets, but with the high-resolution data the dataloader now takes close to 40 seconds to produce one batch, and to the best of my knowledge I am not doing anything computationally expensive. All I am doing is loading 5 frames, resizing them to 192x192, normalizing, and producing a batch of size [B, 3, 5, 96, 192] (only the lower half of each resized frame is kept, hence the height of 96).
When I profiled my code using pprofile, it shows that the following line in threading.py consumes the most time.

File: /home/XXX/anaconda3/envs/msc/lib/python3.9/threading.py
   312|       571|      24675.3|      43.2142|1545.50%|                waiter.acquire()

When I profile the code using the PyTorch profiler, the following call consumes the most time.

Name                                                        Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...        86.28%      673.132s        86.29%      673.143s        2.693s
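
For completeness, this is roughly how I collect that profile (a minimal sketch; the loop, model and variable names are placeholders for my actual training code):

from torch.profiler import profile, ProfilerActivity

# Sketch of the profiling run; train_loader and model stand in for my real objects.
with profile(activities=[ProfilerActivity.CPU]) as prof:
    for step, (window, mel, y) in enumerate(train_loader):
        out = model(window, mel)  # forward pass only, for illustration
        if step >= 10:
            break

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))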

I have read numerous posts on _MultiProcessingDataLoaderIter, but none seemed relevant to my issue. From my understanding this points to a threading/worker issue, so I tried changing the num_workers parameter of the DataLoader, but this did not improve the result.
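
For reference, this is roughly how I construct the loader (a minimal sketch; the batch size and worker count shown here are placeholders, and I have tried several values for num_workers):

from torch.utils.data import DataLoader

train_dataset = The_Dataset(root_path, split="train")  # root_path / split are placeholders
train_loader = DataLoader(
    train_dataset,
    batch_size=8,              # placeholder value
    shuffle=True,
    num_workers=8,             # I have experimented with different values here
    pin_memory=True,
    persistent_workers=True,   # keeps workers alive between epochs
    prefetch_factor=2,
)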

# Imports included for completeness; helpers such as get_dataset_seg, load_window,
# final_audio, crop_audio_window, config and syncnet_Ta come from my own modules.
from glob import glob
from os.path import join

import numpy as np
import torch
import torchaudio
from torch.utils.data import Dataset
from torchvision import transforms


class The_Dataset(Dataset):
    def __init__(self, root_path, split):
        super(The_Dataset, self).__init__()
        self.all_videos = get_dataset_seg(root_path, split)
        self.img_transform = transforms.Compose([
            transforms.Resize((config.img_size, config.img_size)),
            transforms.ToTensor(),
        ])
        self.audio_tranform = torch.nn.Sequential(
            torchaudio.transforms.MelSpectrogram(
                sample_rate=16000, n_fft=config.n_fft, win_length=config.win_size,
                hop_length=config.hop_size, f_min=config.fmin, f_max=config.fmax,
                n_mels=80, power=1, norm='slaney', mel_scale='slaney'),
            torchaudio.transforms.AmplitudeToDB('amplitude'),
        ).to("cpu")

    def __len__(self):
        return len(self.all_videos)

    def __getitem__(self, idx):  
        while 1:
            idx = np.random.randint(0, len(self))  # Low is inclusive, high is exclusive
            vidname = self.all_videos[idx]
            vid_frames = glob(join(vidname, "*.jpg"))
            vid_frames = sorted(vid_frames)

            real_frame_idx = np.random.randint(0, len(vid_frames) - 5)
            real_frame = vid_frames[real_frame_idx]

            fake_frame_idx = np.random.randint(0, len(vid_frames) - 5)
            while fake_frame_idx == real_frame_idx or abs(real_frame_idx - fake_frame_idx) > 50:  # Allows for a max shift of 2 seconds! Beyond this, the network does not learn anything
                fake_frame_idx = np.random.randint(0, len(vid_frames) - 5)
            fake_frame = vid_frames[fake_frame_idx]  # After this, the chosen fake index cannot equal the real index

            if np.random.choice([True, False]):
                chosen = real_frame
                y = torch.ones(1).float()
            else:
                chosen = fake_frame
                y = torch.zeros(1).float()


            window_images,_ = load_window(chosen, self.img_transform, config.hori_flip)  # Will be BGR
            if window_images is None:
                # warnings.warn("Broken Window")
                continue

            try:
                wavpath = join(vidname, "audio.wav")
                wav2 = torch.squeeze(final_audio.load_wav(wavpath))
                orig_mel2 = final_audio.my_spectrogram(wav2, self.audio_tranform, config.preemphasis).T
                mel2 = crop_audio_window(orig_mel2.clone(), real_frame)
                if mel2.shape[0] != syncnet_Ta:
                    continue
            except Exception:
                # warnings.warn("Broken Audio")
                continue
            window = torch.stack(window_images, dim=1)
            window = window[:, :, window.shape[2] // 2:, :]  # keep only the lower half of each frame (height 192 -> 96)

            mel = torch.unsqueeze(mel2.T, dim=0)
            return window, mel, y

I have also verified that this is not an issue with my model: when I pass tensors with random values of the desired size, so that the dataloader is not used at all, training speed is back to normal.
I presume the issue relates to the way the items are being stacked to form a batch, because when I time the __getitem__ function, it takes less than 0.02 s per call.
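
This is roughly how I timed it (a minimal sketch; train_dataset and train_loader are the placeholders from above):

import time

# Time __getitem__ directly (no worker processes, no collation).
t0 = time.time()
for _ in range(20):
    _ = train_dataset[0]  # the index is ignored inside __getitem__; a random one is drawn
print("per item:", (time.time() - t0) / 20)

# Time full batches through the DataLoader (includes worker start-up,
# inter-process transfer and the default_collate stacking).
batch_iter = iter(train_loader)
t0 = time.time()
for _ in range(5):
    _ = next(batch_iter)
print("per batch:", (time.time() - t0) / 5)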

Any assistance would be highly appreciated.