Chunked loss produces error with DistributedDataParallel

I am having a problem with chunked loss calculation when using DistributedDataParallel (the code works fine on a single GPU). I use a single node with 4 GPUs and am training a transformer model for NLP. Instead of feeding the whole batch to the final linear layer that maps to the vocabulary dimension (called generator in the code below), I split the batch into chunks. This is common practice, see e.g. http://nlp.seas.harvard.edu/2018/04/03/attention.html (class MultiGPULossCompute).
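For context, the surrounding setup looks roughly like this (a simplified sketch rather than my exact code; make_transformer, d_model, vocab_size, pad_idx and rank are placeholders):

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# one process per GPU; rank is the local GPU index
dist.init_process_group("nccl", rank=rank, world_size=4)
torch.cuda.set_device(rank)

model = make_transformer(vocab_size, d_model)       # transformer whose final Linear to the vocabulary is called "generator"
model = DDP(model.cuda(rank), device_ids=[rank])
generator = model.module.generator                  # Linear(d_model, vocab_size), applied chunk by chunk below
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
optimizer = torch.optim.Adam(model.parameters())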

The code I use for loss calculation (where x is the tensor of model activations):

x_copy = x.clone().detach()
x_copy.requires_grad = True

chunk_loss_all = 0.0
for chunk_start in range(0, batch_size, chunk_size):
    # Calculate loss per chunk
    chunk_end = min(chunk_start + chunk_size, batch_size)
    chunk_predictions = generator(x_copy[chunk_start:chunk_end])
    chunk_loss = criterion(chunk_predictions.contiguous().view(-1, chunk_predictions.size(-1)),
                           y[chunk_start:chunk_end].contiguous().view(-1))
    chunk_loss_all += chunk_loss

# backward for the chunk losses: this only goes through the generator
# and accumulates gradients in x_copy.grad
chunk_loss_all.backward()

# backward through the rest of the model, feeding in the gradients
# collected on the detached copy
x_gradients = x_copy.grad.view_as(x)
x.backward(gradient=x_gradients)
optimizer.step()
optimizer.zero_grad()
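This all happens inside chunked_loss in loss/compute.py (see the traceback below); the training loop calls it roughly like this (simplified, the real call has a few more arguments):

for src, tgt, y in train_loader:
    # forward through the DDP-wrapped model; x are the decoder activations
    x = model(src, tgt)  # masks omitted for brevity
    # compute the loss chunk-wise and update the parameters as shown above
    chunked_loss(x, y, generator, criterion, optimizer,
                 batch_size=x.size(0), chunk_size=chunk_size)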

The error that is produced:

  File "loss/compute.py", line 75, in chunked_loss
    x.backward(gradient=x_gradients)
  File "/home/dstap1/anaconda3/envs/logos/lib/python3.8/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/dstap1/anaconda3/envs/logos/lib/python3.8/site-packages/torch/autograd/__init__.py", line 97, in backward
    Variable._execution_engine.run_backward(
RuntimeError: has_marked_unused_parameters_ INTERNAL ASSERT FAILED at /opt/conda/conda-bld/pytorch_1579022027550/work/torch/csrc/distributed/c10d/reducer.cpp:290, please report a bug to PyTorch.

The problem probably lies in the multiple .backward() calls, but I don't know how to rewrite my code to avoid the error. Any ideas?