# One training step of a hand-rolled data-parallel loop: forward each replica
# on its input shard, accumulate accuracy, gather the per-device losses, and
# take one optimizer step on the summed loss.
#
# NOTE(review): for this scheme to train correctly, `net` must be a list of
# replicas re-created from a single master module (nn.parallel.replicate)
# *every iteration*, so that backward() accumulates all replica gradients
# onto the one parameter set the optimizer owns.  If the replicas are built
# once, outside the loop, each device effectively keeps its own (stale or
# independently drifting) copy — which matches the reported symptom of one
# replica's loss blowing up while the run never converges.  TODO confirm how
# `net` is constructed upstream.
output = nn.parallel.parallel_apply(net, input_)

# Predicted class per shard: argmax index over the last (class) dimension.
# (assumes `devices` is an iterable of integer indices 0..k-1 usable to
# index the output/label lists — TODO confirm)
predict = [torch.max(output[i], -1)[1] for i in devices]
correct += sum((predict[i] == label_[i]).sum().item() for i in devices)

# Per-shard losses, each computed in parallel on its own device.
loss = nn.parallel.parallel_apply(criterion, list(zip(output, label_)))
# gather() concatenates along dim 0, so promote each scalar loss to shape (1,).
loss = [shard_loss.unsqueeze(0) for shard_loss in loss]
loss_ = torch.sum(nn.parallel.gather(loss, target_device=devices[-2]))

optimizer.zero_grad()
# NOTE(review): retain_graph=True is a red flag here — a fresh graph is built
# by the forward pass every iteration, so it should not be needed.  Needing
# it usually means part of the graph (e.g. a replicate/Broadcast performed
# once, outside the loop) is being backpropagated through repeatedly — the
# same root cause as the replica-staleness note above.  Kept for now to
# preserve current behavior; drop it once replication is moved inside the
# loop.
loss_.backward(retain_graph=True)
optimizer.step()

running_loss += float(loss_)
print("epoch [%d] round_within_load [%d] accuracy: %.3f total loss: %.3f" % (epoch, round, correct*1.0/total, running_loss))
Above is a piece of code. The first several rounds were okay, but after about 50 rounds loss[0] became huge while loss[1:4] stayed very reasonable. Training on exactly the same data on a single GPU gives 95% train accuracy and 93% validation accuracy, yet this parallel script doesn't converge at all.