DDP training process hangs when some processes don't compute the loss

I use multiprocessing and DistributedDataParallel (DDP) to train on multiple GPUs. My code:

    for graphs in self.dataloader:
        graphs = [self._data_to_gpu(data, self.device) for data in graphs]
        graph_labels = torch.cat([data.pop('label') for data in graphs])
        try:
            embeddings = torch.cat(self.model(graphs))
        except RuntimeError:  # meant to catch CUDA out-of-memory errors on very large graphs
            pass
        else:
            optimizer.zero_grad()
            pairs = miner(embeddings, graph_labels) if miner is not None else None
            loss = criterion(embeddings, graph_labels, pairs)
            loss.backward()
            # clip before step(), otherwise clipping has no effect on this update
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_norm)
            optimizer.step()
            if self.rank == 0:
                total_loss.append(loss.item())
        if self.rank == 0:
            bar.update()

I use the try/except to catch the GPU out-of-memory error caused by some very large inputs, but it didn't work. When one process runs into the except branch, the whole training job hangs.
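The following toy simulation in plain Python (not PyTorch) shows why one process skipping its step can stall the others. Each thread plays one DDP rank. A threading.Barrier stands in for the gradient all-reduce that DDP performs during loss.backward(). The function simulate_step and its parameters are hypothetical and exist only for illustration.

```python
import threading


def simulate_step(world_size: int, skipping_rank: int, timeout: float = 0.5) -> dict:
    """Toy model of one DDP step: barrier.wait() stands in for the
    gradient all-reduce launched during loss.backward()."""
    barrier = threading.Barrier(world_size)
    results = {}

    def rank_fn(rank: int) -> None:
        if rank == skipping_rank:
            # Like the `except: pass` branch: this rank never reaches backward().
            results[rank] = "skipped"
            return
        try:
            # In real DDP the waiting ranks block until the process-group timeout
            # (if it is enforced at all); the short timeout only keeps this demo from hanging.
            barrier.wait(timeout=timeout)
            results[rank] = "synced"
        except threading.BrokenBarrierError:
            results[rank] = "stuck waiting for peers"

    threads = [threading.Thread(target=rank_fn, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```

With simulate_step(3, skipping_rank=1), ranks 0 and 2 end up "stuck waiting for peers". With skipping_rank=-1, meaning no rank skips, all three report "synced".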

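If you are dealing with use cases where not all parameters are used in the forward pass, passing find_unused_parameters=True to DDP might help avoid these hangs. However, I'm unsure whether this would (or should) also work when the loss is not computed at all, since that process wouldn't be able to contribute to the current training step. The underlying problem is that DDP all-reduces gradients across all processes during backward(). If one process skips backward(), the others wait for it in that collective and the job hangs.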
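A common workaround, which the answer above does not propose, is to have all processes agree before calling backward(). Each rank reports whether its forward pass succeeded. An all-reduce with ReduceOp.MIN combines these flags, and every rank skips the batch together if any rank failed.

The sketch below rests on several assumptions:

- It assumes an already initialized default process group.
- all_ranks_ok and train_step are hypothetical helpers.
- The miner argument from the question is omitted for brevity.
- torch.cuda.OutOfMemoryError requires PyTorch 1.13 or later.

It is also only assumed to be safe without find_unused_parameters=True. With that flag, skipping backward() after a successful forward may trigger DDP's "Expected to have finished reduction in the prior iteration" error, so verify against your PyTorch version.

```python
import torch
import torch.distributed as dist


def all_ranks_ok(local_ok: bool, device: torch.device) -> bool:
    """Hypothetical helper: True only if every rank reports success.

    Every rank must call this on every iteration; otherwise this collective
    would itself hang.
    """
    flag = torch.tensor([1 if local_ok else 0], dtype=torch.int32, device=device)
    dist.all_reduce(flag, op=dist.ReduceOp.MIN)  # 0 if any rank reported failure
    return bool(flag.item())


def train_step(model, graphs, labels, criterion, optimizer, max_norm, device):
    """Hypothetical training step that skips a batch on all ranks together."""
    embeddings = None
    try:
        embeddings = torch.cat(model(graphs))
        ok = True
    except torch.cuda.OutOfMemoryError:  # PyTorch >= 1.13; older versions raise RuntimeError
        ok = False

    if not all_ranks_ok(ok, device):
        # At least one rank failed, so every rank skips backward() and step():
        # no rank is left waiting in DDP's gradient all-reduce.
        del embeddings
        optimizer.zero_grad(set_to_none=True)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return None

    optimizer.zero_grad()
    loss = criterion(embeddings, labels)  # miner omitted for brevity
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    return loss.item()
```

The flag all-reduce adds one tiny collective per iteration. In exchange, the decision to skip is identical on every rank, so the backward-pass collectives stay matched.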