Loss blows up on the first GPU but stays reasonable on the rest

            # net, input_, label_ and criterion are lists with one replica / per-GPU batch each
            output = nn.parallel.parallel_apply(net, input_)
            predict = [torch.max(output[i], -1)[1] for i in devices]
            correct += sum([(predict[i] == label_[i]).sum().item() for i in devices])
            # one loss per replica, computed on its own GPU
            loss = nn.parallel.parallel_apply(criterion, list(zip(output, label_)))
            for idx in range(len(loss)):
                loss[idx] = loss[idx].unsqueeze(0)
            # gather the per-GPU losses onto one device and sum them
            loss_ = torch.sum(nn.parallel.gather(loss, target_device=devices[-2]))
            optimizer.zero_grad()
            loss_.backward(retain_graph=True)
            optimizer.step()
            running_loss += float(loss_)
            print("epoch [%d] round_within_load [%d] accuracy: %.3f total loss: %.3f"
                  % (epoch, round, correct * 1.0 / total, running_loss))

Above is the relevant piece of my code. The first several rounds were fine, but after about 50 rounds loss[0] got huge while loss[1:4] stayed very reasonable. Training on exactly the same data on a single GPU gives 95% train accuracy and 93% validation accuracy, yet this parallel script doesn't converge at all.
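
For comparison, here is the single-call nn.DataParallel version of the loop that (as far as I understand) the manual parallel_apply / gather steps above are meant to reproduce. The model, data, and hyperparameters below are placeholders, not my real setup:

    import torch
    import torch.nn as nn

    devices = [0, 1, 2, 3]                                # assumption: four visible GPUs
    net = nn.Sequential(nn.Linear(128, 64), nn.ReLU(),
                        nn.Linear(64, 10))                # placeholder model
    net = nn.DataParallel(net, device_ids=devices).cuda(devices[0])
    criterion = nn.CrossEntropyLoss()                     # assumption: classification loss
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

    # toy data just to make the sketch self-contained
    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.randn(1024, 128),
                                       torch.randint(0, 10, (1024,))),
        batch_size=64, shuffle=True)

    for epoch in range(10):
        running_loss, correct, total = 0.0, 0, 0
        for input_, label_ in loader:
            input_ = input_.cuda(devices[0], non_blocking=True)
            label_ = label_.cuda(devices[0], non_blocking=True)

            output = net(input_)              # scatter / replicate / gather happen inside
            loss = criterion(output, label_)  # one loss, already on the output device

            optimizer.zero_grad()
            loss.backward()                   # single backward pass, no retain_graph
            optimizer.step()

            running_loss += loss.item()
            correct += (output.argmax(dim=-1) == label_).sum().item()
            total += label_.size(0)
        print("epoch [%d] accuracy: %.3f total loss: %.3f"
              % (epoch, correct * 1.0 / total, running_loss))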