Training an RNN with detached states

I have an RNN model (described below), and I don't know how to train it efficiently with DataParallel.

Given a batch of training samples, the model runs for 5 cycles. The output of each cycle is used to compute a loss and update the model parameters. Some outputs of each cycle are detached and fed to the next cycle as inputs.

The loop looks roughly like this (pseudocode):

```python
for cycle in range(total_cycles):
    # inside the model, `outputs` are already detached, but `y` is not
    y, outputs = model(x, inputs)
    loss = loss_function(target, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    inputs = outputs  # outputs is a list of feature maps
```

In other words, the model parameters are updated once per cycle. In each cycle, the model receives a list of detached states from the previous cycle, builds a new computational graph, and computes gradients and updates the parameters from that graph alone.
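A minimal, CPU-runnable sketch of this per-cycle scheme, assuming a toy model: `CycleModel`, its layer sizes, the MSE loss, SGD, and the zero initial state are hypothetical stand-ins, not the actual model from this post. Because the states are detached, each cycle's `loss.backward()` only traverses that cycle's graph, so no `retain_graph=True` is needed.

```python
import torch
import torch.nn as nn


class CycleModel(nn.Module):
    """Hypothetical stand-in: returns a prediction `y` (still attached to the
    graph) and a list of detached state tensors for the next cycle."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.cell = nn.Linear(2 * dim, dim)
        self.head = nn.Linear(dim, dim)

    def forward(self, x, inputs):
        h = torch.tanh(self.cell(torch.cat([x, inputs[0]], dim=1)))
        y = self.head(h)
        # Detach the state so the next cycle starts a fresh graph.
        return y, [h.detach()]


def train_cycles(total_cycles: int = 5, batch: int = 4, dim: int = 8):
    torch.manual_seed(0)
    model = CycleModel(dim)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_function = nn.MSELoss()
    x = torch.randn(batch, dim)
    target = torch.randn(batch, dim)
    inputs = [torch.zeros(batch, dim)]  # hypothetical initial state
    losses = []
    for cycle in range(total_cycles):
        y, outputs = model(x, inputs)
        loss = loss_function(y, target)
        optimizer.zero_grad()
        loss.backward()  # gradients flow only through this cycle's graph
        optimizer.step()
        inputs = outputs  # detached states become the next cycle's inputs
        losses.append(loss.item())
    return losses, inputs
```

For example, `losses, state = train_cycles()` returns five per-cycle loss values and a final detached state of shape (4, 8).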

This scheme works fine on a single GPU. However, when I trained it on multiple GPUs with DataParallel, every output tensor ended up on gpu-0 (checked with outputs[0].get_device()), so gpu-0 used significantly more memory than the other GPUs. As a result, I cannot train a large enough model, and GPU utilization stays low.
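The gpu-0 placement matches how `nn.DataParallel` works: it scatters inputs along dim 0, runs one replica per GPU, and gathers every returned tensor (including tensors nested in lists) onto `output_device`, which defaults to `device_ids[0]`. The sketch below, using a hypothetical `StateModel`, reports where `y` and the returned states land; it falls back to the plain module on CPU when fewer than two GPUs are visible.

```python
import torch
import torch.nn as nn


class StateModel(nn.Module):
    # Hypothetical module: returns y and a list of detached feature maps.
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, state):
        h = torch.relu(self.proj(x) + state)
        return h, [h.detach()]


def output_devices(batch: int = 8, dim: int = 8):
    model = StateModel(dim)
    if torch.cuda.device_count() > 1:
        # Replicas run on every GPU, but outputs are gathered onto
        # output_device, which defaults to device_ids[0] (cuda:0).
        model = nn.DataParallel(model.cuda())
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")  # no multi-GPU: run the plain module
    x = torch.randn(batch, dim, device=device)
    state = torch.zeros(batch, dim, device=device)
    y, outputs = model(x, state)
    return y.device, outputs[0].device, outputs[0].shape
```

On a multi-GPU machine both returned devices are `cuda:0`, which is why the gathered states (and anything computed from them) accumulate on gpu-0.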

I would like the outputs to stay split across the GPUs, with each GPU computing and holding its own share, to improve the memory and time efficiency of training.