DataParallel Multi-GPU stuck on backward pass

Hello! I’m running into an issue with PyTorch 1.1 while training a model on time series data. The data is held in a list of tensors, and each tensor can be split into multiple batches for parallelization.

When the batch size is 1, only one of the two GPUs is used (which is expected), and the model trains smoothly. However, when the batch size is greater than 1 (this is not the mini_batch variable, but the number of batches the model is allowed to use for parallelization), the program gets stuck at backpropagation, on the line loss.backward(). When debugging, the loss variable has identical properties (apart from the float value itself) in the multi-batch and single-batch cases (screenshot attached below), so I can’t figure out what is going wrong.
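
As a point of reference, here is a minimal sketch of the kind of setup described above, assuming the model is wrapped in torch.nn.DataParallel. TinyNet, the tensor shapes and the class count are hypothetical stand-ins and are not taken from my actual training code. It falls back to a single device when fewer than two GPUs are visible.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real model; outputs log-probabilities."""

    def __init__(self, n_features=8, n_classes=3):
        super().__init__()
        self.fc = nn.Linear(n_features, n_classes)

    def forward(self, x):
        return F.log_softmax(self.fc(x), dim=1)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TinyNet().to(device)
if torch.cuda.device_count() > 1:
    # DataParallel splits dim 0 of the input across the visible GPUs,
    # so an input with a single row can only ever occupy one GPU.
    model = nn.DataParallel(model)

x = torch.randn(4, 8, device=device)               # 4 rows along dim 0
target = torch.randint(0, 3, (4,), device=device)  # hypothetical class labels

output = model(x)                  # per-GPU outputs are gathered on one device
loss = F.nll_loss(output, target)
loss.backward()                    # corresponds to the call that hangs for me
```

In this sketch the hang would show up on the last line, and only when the DataParallel branch is taken.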

I’ve read that PyTorch used to have problems with Tesla K80 GPUs, but I also tried Tesla P100s and the problem persists, so the GPU model doesn’t seem to be the cause. I’ve also looked at pretty much all of the related posts here, but none seem to address or solve the issue I am having. My training code is attached below for reference.

Any help is much appreciated. Thank you!

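    import torch
    import torch.nn.functional as F

    # Method of my training class; self.train_loader, self.device, self.w
    # and self.print_eval are defined elsewhere.
    def train_minibatch(self, model, optimizer, mini_batch=10):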
        model.train()

        train_loss_value = 0
        predictions = []
        targets = []

        for batch_idx, ts in enumerate(self.train_loader):

            batch_loss_value = 0
            batch_predictions = []
            batch_targets = []

            optimizer.zero_grad()

            for i in range(len(ts)):

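                output = model(ts[i].to(self.device))
                # cues is defined outside this snippet
                target = torch.argmax(cues[:, i], dim=1)
                # max() returns (values, indices); keep the indices as predicted classes
                prediction = output.max(1, keepdim=True)[1]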

                torch.cuda.synchronize()

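                # Sum of per-sample NLL losses, scaled for gradient accumulation
                loss = sum(
                    F.nll_loss(output[j, :].unsqueeze(0), target[j].unsqueeze(-1), weight=self.w)
                    for j in range(len(target))
                )
                loss = loss / mini_batch
                loss.backward()  # the program hangs here when more than one batch is used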

                if i > 0 and i % mini_batch == 0:
                    optimizer.step()
                    optimizer.zero_grad()

                train_loss_value += loss.item()
                predictions.extend(prediction.tolist())
                targets.extend(target.tolist())

                batch_loss_value += loss.item()
                batch_predictions.extend(prediction.tolist())
                batch_targets.extend(target.tolist())

            self.print_eval(batch_loss_value, batch_predictions, batch_targets, idx=batch_idx, header='batch idx:')
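
To make the accumulation logic in the inner loop easier to follow, the sketch below traces when optimizer.step() fires under the condition `i > 0 and i % mini_batch == 0`. It is plain Python and needs no GPU. The helper name accumulation_schedule and the item count of 25 are made up for illustration.

```python
def accumulation_schedule(n_items, mini_batch=10):
    """Trace the condition `i > 0 and i % mini_batch == 0` from train_minibatch.

    Returns (step_indices, group_sizes, leftover), where leftover is the number
    of backward() calls whose gradients are never followed by a step.
    """
    step_indices, group_sizes = [], []
    pending = 0                        # backward() calls since the last step
    for i in range(n_items):
        pending += 1                   # loss.backward() accumulates a gradient
        if i > 0 and i % mini_batch == 0:
            step_indices.append(i)     # optimizer.step() would run here
            group_sizes.append(pending)
            pending = 0                # optimizer.zero_grad()
    return step_indices, group_sizes, pending


steps, sizes, leftover = accumulation_schedule(n_items=25, mini_batch=10)
print(steps, sizes, leftover)  # [10, 20] [11, 10] 4
```

With 25 items and mini_batch=10, steps happen at i = 10 and i = 20. The first step covers 11 accumulated gradients, and the last 4 are never applied, because the next optimizer.zero_grad() at the top of the outer loop clears them.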