I use multiprocessing and DistributedDataParallel (DDP) to train on multiple GPUs. My training loop:
```python
for graphs in self.dataloader:
    graphs = [self._data_to_gpu(data, self.device) for data in graphs]
    graph_labels = torch.cat([g.pop('label') for g in graphs])
    try:
        embeddings = torch.cat(self.model(graphs))
    except:  # meant to swallow the CUDA OOM raised by some very large graphs
        pass
    else:
        optimizer.zero_grad()
        loss = criterion(
            embeddings,
            graph_labels,
            miner(embeddings, graph_labels) if miner is not None else None,
        )
        loss.backward()
        # clip gradients before the optimizer step
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_norm)
        optimizer.step()
        if self.rank == 0:
            total_loss.append(loss.item())
            bar.update()
```
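For context, the processes are launched and the model wrapped in the usual way, roughly like this (a simplified sketch, not my exact script; `build_model` and `args` stand in for my actual model factory and config object):

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def worker(rank, world_size, args):
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    # one process per GPU, NCCL backend for CUDA tensors
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    model = build_model(args).to(rank)     # build_model: placeholder for my model factory
    model = DDP(model, device_ids=[rank])  # gradients are all-reduced during backward()
    ...                                    # optimizer, dataloader, and the loop above

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(worker, args=(world_size, args), nprocs=world_size)
```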
I added the try/except to swallow the CUDA out-of-memory error that a few oversized graphs trigger, but it doesn't work: as soon as one process falls into the except branch, the whole training run hangs.
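Narrowing the bare except to the OOM error specifically does not help either (sketch below; `torch.cuda.OutOfMemoryError` requires PyTorch >= 1.13, on older versions I match "out of memory" in the `RuntimeError` text instead). I suspect the real problem is that the rank which skips `loss.backward()` never enters DDP's gradient all-reduce, so the other ranks block inside the collective forever:

```python
try:
    embeddings = torch.cat(self.model(graphs))
except torch.cuda.OutOfMemoryError:  # subclass of RuntimeError, PyTorch >= 1.13
    torch.cuda.empty_cache()         # release the partial allocations of the failed forward
    continue                         # this rank skips backward()/step(), but the other
                                     # ranks still wait for it in the all-reduce -> hang
```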