Chunked loss calculation produces an error with DistributedDataParallel

I am having a problem with chunked loss calculation when using DistributedDataParallel (the code works fine on a single GPU). I use a single node with 4 GPUs and am training a transformer model for NLP. Instead of feeding the whole batch to the final linear layer that maps to the vocabulary dimension (called generator in the code below), I split the batch into chunks. This is common practice; see e.g. the MultiGPULossCompute class from The Annotated Transformer.

The code I use for loss calculation (where x holds the model activations):

x_copy = x.clone().detach()
x_copy.requires_grad = True

chunk_loss_all = 0.0
for chunk_start in range(0, batch_size, chunk_size):
    # Calculate loss per chunk
    chunk_end = min(chunk_start + chunk_size, batch_size)
    chunk_predictions = generator(x_copy[chunk_start:chunk_end])
    # targets holds the gold labels for the batch (same leading dimension as x)
    chunk_loss = criterion(chunk_predictions.contiguous().view(-1, chunk_predictions.size(-1)),
                           targets[chunk_start:chunk_end].contiguous().view(-1))
    chunk_loss_all += chunk_loss

# backward for the chunk losses (through the generator only; stops at x_copy)
chunk_loss_all.backward()

# backward through the rest of the model
x_gradients = x_copy.grad.view_as(x)
x.backward(gradient=x_gradients)
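
For context, the surrounding training step looks roughly like this (simplified; the names model, optimizer, dataloader and the batch attributes are placeholders for my actual setup, and generator is a submodule of the wrapped model):

model = torch.nn.parallel.DistributedDataParallel(model.to(device), device_ids=[local_rank])
generator = model.module.generator  # final vocab projection, part of the wrapped model

for batch in dataloader:
    optimizer.zero_grad()
    x = model(batch.src, batch.tgt)   # activations before the generator
    chunked_loss(x, batch.tgt_y)      # the code above, including both backward calls
    optimizer.step()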

The error that is produced:

  File "loss/", line 75, in chunked_loss
  File "/home/dstap1/anaconda3/envs/logos/lib/python3.8/site-packages/torch/", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/dstap1/anaconda3/envs/logos/lib/python3.8/site-packages/torch/autograd/", line 97, in backward
RuntimeError: has_marked_unused_parameters_ INTERNAL ASSERT FAILED at /opt/conda/conda-bld/pytorch_1579022027550/work/torch/csrc/distributed/c10d/reducer.cpp:290, please report a bug to PyTorch.

The problem is probably caused by the two .backward() calls per iteration (one through the generator, one through the rest of the model); as far as I understand, DistributedDataParallel expects exactly one backward pass per forward pass through the wrapped module. I don't know how to rewrite my code to avoid this.
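
One direction I have considered (I am not sure it is correct) is to keep the generator outside of the DDP wrapper, so that DDP only ever sees the single x.backward(...) call, and to average the generator's gradients across workers by hand. A rough sketch of what I mean (trunk is my name for the model without the generator):

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

trunk = DDP(trunk.to(device), device_ids=[local_rank])  # model without the generator
generator = generator.to(device)                        # plain module, not wrapped in DDP

x = trunk(batch.src, batch.tgt)   # single forward through the DDP module
chunked_loss(x, batch.tgt_y)      # generator backward + x.backward, as above

# generator is outside DDP, so its gradients have to be averaged manually
for p in generator.parameters():
    if p.grad is not None:
        dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
        p.grad /= dist.get_world_size()

Would this be the right approach, or is there a cleaner way to make chunked loss work with DistributedDataParallel? Any ideas are appreciated.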