Tensorboard SummaryWriter add_audio + DDP not working

Hi,

I’m having issues with SummaryWriter.add_audio() when using DDP. Only when I use one GPU do all images get logged, if not they are processed but don’t appear in the tensorboard’s events file output.

I think the problem is that the operation isn’t thread-safe or atomic. I’ve tried by forcing a flush() but didn’t work.

Here is the relevant pice of code

        for logger in self.trainer.logger:
            if isinstance(logger, TensorBoardLogger):

                speaker_ids = (
                    speaker_ids if type(speaker_ids) is list else [speaker_ids]
                )
                utt_ids = utt_ids if type(utt_ids) is list else [utt_ids]

                for i, (speaker_id, utt_id) in enumerate(zip(speaker_ids, utt_ids)):
                    if run_type == "Test" or (
                        run_type == "Validation" and utt_id in self.trainer.examples
                    ):
                        print(f"Device: {self.device} Spk: {speaker_id} Utt: {utt_id}")

                        if self.global_step == 0:
                            logger.experiment.add_audio(
                                f"Speaker_{speaker_id}/{utt_id}_original",
                                y_audio[i].squeeze().detach(),
                                self.global_step,
                                self.trainer.mel_spec.sample_rate,
                            )
                        logger.experiment.add_audio(
                            f"Speaker_{speaker_id}/{utt_id}_generated",
                            y_hat_audio[i].squeeze().detach(),
                            self.global_step,
                            self.trainer.mel_spec.sample_rate,
                        )
                logger.experiment.flush()

This is the output:

INFO:trainer.dataset:Validation set has 600 examples
Validation sanity check:  23%|███████████████████████████████████████████████████▊                                                                                                                                                                          | 14/60 [00:04<00:08,  5.17it/s]
Device: cuda:4 Spk:     0 Utt: LJ012-0149
Validation sanity check:  33%|██████████████████████████████████████████████████████████████████████████                                                                                                                                                    | 20/60 [00:06<00:09,  4.44it/s]
Device: cuda:0 Spk:     8 Utt: dartagnan01_24_dumas_0149
Validation sanity check:  35%|█████████████████████████████████████████████████████████████████████████████▋                                                                                                                                                | 21/60 [00:06<00:08,  4.39it/s]
Device: cuda:2 Spk:     8 Utt: dartagnan01_24_dumas_0149
Validation sanity check:  38%|█████████████████████████████████████████████████████████████████████████████████████                                                                                                                                         | 23/60 [00:06<00:07,  5.22it/s]
Device: cuda:3 Spk:     1 Utt: littleminister_20_barrie_0001
Validation sanity check:  65%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                                             | 39/60 [00:09<00:03,  6.47it/s]
Device: cuda:0 Spk:     4 Utt: bambatse_21_haggard_0013
Validation sanity check:  78%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                | 47/60 [00:10<00:02,  6.03it/s]
Device: cuda:6 Spk:     2 Utt: widowbarnaby_25_trollope_0117
Validation sanity check:  83%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                     | 50/60 [00:11<00:01,  6.23it/s]
Device: cuda:3 Spk:     4 Utt: bambatse_21_haggard_0013
Validation sanity check:  88%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                          | 53/60 [00:11<00:01,  6.50it/s]
Device: cuda:1 Spk:     4 Utt: bambatse_21_haggard_0013
Validation sanity check:  92%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                  | 55/60 [00:12<00:00,  5.06it/s]
Device: cuda:3 Spk:     3 Utt: bigbluesoldier_04_hill_0187
Validation sanity check: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:14<00:00,  2.76it/s]Device: cuda:8 Spk:     7 Utt: annualreportseducation_11_mann_0203
Device: cuda:7 Spk:     5 Utt: internationalshortstories1_08_patten_0564
Device: cuda:3 Spk:     7 Utt: annualreportseducation_11_mann_0203
Device: cuda:2 Spk:     7 Utt: annualreportseducation_11_mann_0203
``

Not sure how you set up your training script, but the common idiom is to use rank 0 for TensorBoard logging; usually after sync’ing workers either implicitly (e.g. start/end of epoch) or explicitly (barrier() call)

I’m using PytorchLightning which takes care of most of the configuration code.
I know that is usual to just log rank0 outputs but I want to log one audio per speaker or specific audios in the dev set. Hence, I need to save audios processed on different GPUs.

I thought that workers sync’ing was for example to combine all models losses (one per GPU) and run backpropagation on the total amount.

I will check how to synchronize the logger.