Strange behavior during validation

Hi. I’m seeing strange behavior from the validation loader in my training loop, and I can’t pinpoint what causes it. Basically, I’m running a standard training loop with validation every N steps:

    model.train()
    for epoch in range(max(0, last_epoch), training_epochs):
        for i, batch in enumerate(train_loader):

            audio_clean, audio_ns = batch

            audio_clean = audio_clean.squeeze(0)
            audio_ns = audio_ns.squeeze(0)

            mel_clean = torch.log(cumulative_laplace_norm(melspec(audio_clean))+1e-6).permute(0, 2, 1).to(device)
            mel_ns = torch.log(cumulative_laplace_norm(melspec(audio_ns))+1e-6).permute(0, 2, 1).to(device)

            # train step
            optim_g.zero_grad()
            mel_output = model(mel_ns)

            g_loss = mel_MSE(mel_clean, mel_output)
            g_loss.backward()

            optim_g.step()

            if rank == 0:
                # validation
                if steps % a.validation_interval == 0:
                    model.eval()
                    val_loss_tot_tot = 0
                    with torch.no_grad():
                        for j, batch in enumerate(validation_loader):
                            audio_clean, audio_ns = batch

                            mel_clean = torch.log(cumulative_laplace_norm(melspec(audio_clean)) + 1e-6).permute(0, 2, 1).to(device)
                            mel_ns = torch.log(cumulative_laplace_norm(melspec(audio_ns)) + 1e-6).permute(0, 2, 1).to(device)

                            mel_output = model(mel_ns)

                            val_loss = mel_MSE(mel_clean, mel_output)
                            val_loss_tot_tot += val_loss.item()

                        val_loss_tot = val_loss_tot_tot / (j + 1)

                    model.train()

            if scheduler_g:
                scheduler_g.step()

            steps += 1

The train and validation samplers are set up identically (different datasets, obviously, but with the same structure):

train_sampler = RandomSampler(trainset)
train_loader = DataLoader(
        trainset,
        num_workers=24,
        shuffle=False,
        sampler=train_sampler,
        batch_size=32,
        pin_memory=True,
        drop_last=True,
        prefetch_factor=4,
        persistent_workers=True,
    )

val_sampler = RandomSampler(validset)
validation_loader = DataLoader(
          validset,
          num_workers=24,
          shuffle=False,
          sampler=val_sampler,
          batch_size=32,
          pin_memory=True,
          drop_last=True,
          prefetch_factor=4,
          persistent_workers=True,
      )

However, my validation behaves weirdly. It pulls exactly num_workers batches from the loader, then hangs for almost a minute, then pulls another num_workers batches and hangs again, and so on. This does not happen during training; there, batches come in smoothly the whole time. Changing num_workers only changes how quickly the validation loader reaches the hang. persistent_workers and prefetch_factor don’t affect this at all.

I’m NOT running DDP; this is a single-GPU run.
I could blame the fact that I’m running several trainings per GPU on the same machine (i.e. CPU bottlenecking), but this happens even when only a single training is running.

I assume it’s not related to .to(device), since the loader hangs at the very beginning of the validation cycle.

Screenshot of GPU activity during validation, from btop:

Worth noting that my dataloader is quite heavy and runs on the CPU (I’m working with audio files, and there are a bunch of augmentations that aren’t possible with GPU functions). But again, it’s completely fine during the training phase.

Am I missing something, or is this purely a dataloader problem?

Validation has no backward pass, so the compute time per batch can be shorter than the data loading time if the batch size is too small. Could you check the validation performance after increasing the batch size?

Increasing the batch size from 32 to 64 doesn’t change performance at all; if anything, it made things slightly worse, since now there are freezes even within the “periods of work”:

The following code doesn’t exist in the validation loop:

            audio_clean = audio_clean.squeeze(0)
            audio_ns = audio_ns.squeeze(0)

Is it correct?

Not really; those lines are only there just in case, and for 99.9% of training samples they’re effectively a no-op, since all my audio data is 1-D. So I don’t think that’s the issue here.

Could you create Nsight Systems profiles (with NVTX markers around the data loading, forward, backward, etc.) to narrow down which sections of the code might be the bottleneck?
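
Marker placement could look roughly like this (just a sketch of one training pass; train_loader, model, optim_g, mel_MSE and the feature code are your own objects from the loop above):

import torch
import torch.cuda.nvtx as nvtx

# Rough sketch: each phase of the step gets its own NVTX range so it shows up
# as a labeled block on the Nsight Systems timeline. An explicit iterator is
# used so the batch fetch can be wrapped separately from the compute.
it = iter(train_loader)
while True:
    nvtx.range_push("data_loading")
    batch = next(it, None)
    nvtx.range_pop()
    if batch is None:
        break
    audio_clean, audio_ns = batch

    nvtx.range_push("features")
    mel_clean = torch.log(cumulative_laplace_norm(melspec(audio_clean)) + 1e-6).permute(0, 2, 1).to(device)
    mel_ns = torch.log(cumulative_laplace_norm(melspec(audio_ns)) + 1e-6).permute(0, 2, 1).to(device)
    nvtx.range_pop()

    nvtx.range_push("forward")
    mel_output = model(mel_ns)
    g_loss = mel_MSE(mel_clean, mel_output)
    nvtx.range_pop()

    nvtx.range_push("backward")
    optim_g.zero_grad()
    g_loss.backward()
    optim_g.step()
    nvtx.range_pop()

Running the script under nsys profile --trace=cuda,nvtx then shows each range on the timeline, and the same markers can be placed around the validation loop.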

To be honest, I have no idea what causes this.

You can find the whole file here, but it’s 1 GB: google drive

And for some reason NVTX didn’t log __getitem__ at all.

If I understand the stack trace and the enum_next correctly, the time is being spent in the data loading loop, so you might need to narrow down further whether this data loading time is expected or not.
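
If Nsight is getting in the way, a plain timing loop would also answer that (a rough sketch, not your exact code; the feature extraction is collapsed into the compute section):

import time
import torch

# Rough sketch: time the batch fetch and the compute separately to see which
# side the ~1 minute stalls belong to. validation_loader, model, melspec, etc.
# refer to your own objects.
model.eval()
it = iter(validation_loader)
with torch.no_grad():
    while True:
        t0 = time.perf_counter()
        batch = next(it, None)
        t1 = time.perf_counter()
        if batch is None:
            break

        audio_clean, audio_ns = batch
        mel_ns = torch.log(cumulative_laplace_norm(melspec(audio_ns)) + 1e-6).permute(0, 2, 1).to(device)
        mel_output = model(mel_ns)
        torch.cuda.synchronize()  # so async kernels don't hide the compute time
        t2 = time.perf_counter()

        print(f"fetch: {t1 - t0:.3f}s  compute: {t2 - t1:.3f}s")

If the fetch times show the long stalls while the compute stays small, that would confirm the loader side as the bottleneck.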

For some reason I can’t get NVTX to annotate my dataloader (except __init__). Decorators, context managers, ranges: none of them log anything from __getitem__ for me. Maybe you can suggest something based on the code:

    def __getitem__(self, index):
        rir_index_list = torch.randint(
            0, len(self.rir_list), (len(self.valid_seq_list),)
        )

        rir_list = [self.rir_list[i] for i in rir_index_list]

        noise_index_list = torch.randint(
            0, len(self.valid_ns_list), (len(self.valid_seq_list),)
        )
        valid_ns_list = [self.valid_ns_list[i] for i in noise_index_list]

        ##################################################################
        # clean speech
        spch_filename, seq_start, seq_end = self.valid_seq_list[index]
        spch_clean, fs_spch = torchaudio.load(spch_filename)
        assert self.fs == fs_spch
        spch_clean = spch_clean.squeeze()

        shift_max = min(seq_start - 0, spch_clean.shape[0] - seq_end, self.fs)
        shift = random.randint(-shift_max, shift_max)

        spch_clean = spch_clean[int(seq_start + shift): int(seq_end + shift)]
        ##################################################################

        ##################################################################
        # noise sampler

        ns_filename, seq_start, seq_end = valid_ns_list[index]
        ns, fs_ns = torchaudio.load(ns_filename)

        assert self.fs == fs_ns
        ns = ns.squeeze()
        if len(ns.shape) > 1:
            ns = ns[0]

        shift_max = min(seq_start - 0, ns.shape[0] - seq_end, self.fs)
        shift = random.randint(-shift_max, shift_max)

        noise = ns[int(seq_start + shift): int(seq_end + shift)]
        ##################################################################

        # target clip length in samples
        samples = int(self.fs * self.spch_len)

        rir, fs_rir = torchaudio.load(rir_list[index])
        assert self.fs == fs_rir
        rir = rir.squeeze()

        # reverberant speech: convolve clean speech with the sampled RIR
        spch_reverb = torchaudio.functional.fftconvolve(spch_clean, rir, mode="full")
        spch_reverb = spch_reverb[:samples]
        spch_clean = spch_clean[:samples]
        noise = noise[:samples]

        # mix reverberant speech with noise scaled to a random SNR
        snr = np.random.randint(self.snr_low, self.snr_high)
        snr_coeff = coef_by_snr(spch_clean, noise, snr)
        if np.isinf(snr_coeff):
            snr_coeff = 100
        spch_ns = spch_reverb + snr_coeff * noise

        # normalization
        gain = np.random.uniform(-1, -6)
        spch_ns, _ = torchaudio.sox_effects.apply_effects_tensor(spch_ns.unsqueeze(0), fs_spch, [["norm", f"{gain:.2f}"]])
        spch_clean, _ = torchaudio.sox_effects.apply_effects_tensor(spch_clean.unsqueeze(0), fs_spch, [["norm", f"{gain:.2f}"]])
        return spch_clean[0], spch_ns[0]

My main suspects are torchaudio.load, torchaudio.functional.fftconvolve, and torchaudio.sox_effects.apply_effects_tensor. Which of them could be the main bottleneck?

Okay, I just put time.time() around all the segments of __getitem__ and found that, for some reason, torchaudio.load on the noise files takes 3 seconds for a single file. I’ll try to figure out what’s wrong with my datasets. Thanks for the suggestions!
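
Roughly what the timing looked like (simplified; only the two load calls are shown, and the paths are placeholders for files from my dataset):

import time
import torchaudio

# Simplified version of the timing I put inside __getitem__; the other
# segments were wrapped the same way.
spch_filename = "path/to/speech.wav"
ns_filename = "path/to/noise.wav"

t0 = time.time()
spch_clean, fs_spch = torchaudio.load(spch_filename)
t1 = time.time()
noise, fs_ns = torchaudio.load(ns_filename)
t2 = time.time()

print(f"speech load: {t1 - t0:.3f}s  noise load: {t2 - t1:.3f}s")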

Thanks for the update! For your previous question: you could try num_workers=0 and add the NVTX markers again, as I guess the use of multiple workers might be causing the issue you were seeing before.
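
For example, a small wrapper like this (a rough sketch; ValidSetWithNVTX is just a made-up name, and validset is your existing dataset) keeps the loading in the main process so the ranges from __getitem__ actually end up in the profile:

import torch
import torch.cuda.nvtx as nvtx
from torch.utils.data import DataLoader, Dataset

# Rough sketch: with num_workers=0 the dataset code runs in the main process,
# so NVTX ranges around __getitem__ show up in the Nsight Systems timeline.
class ValidSetWithNVTX(Dataset):
    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        nvtx.range_push("getitem")
        item = self.base[index]
        nvtx.range_pop()
        return item

validation_loader = DataLoader(
    ValidSetWithNVTX(validset),  # validset is your existing dataset
    num_workers=0,               # keep loading in the main process while profiling
    batch_size=32,
    shuffle=False,
    drop_last=True,
)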