Multi GPU GRU AssertionError

I wrap my module with nn.DataParallel, using

device_ids=[2, 3]

During the forward pass, the GRU part raises an AssertionError:

import torch
import torch.nn as nn
from torch.autograd import Variable


class Multi(nn.Module):
    def __init__(self):
        super(Multi, self).__init__()
        self.encoder_gru = nn.GRU(input_size=4, hidden_size=2,
                                  num_layers=1, bidirectional=True)

    def forward(self, input, hidden):
        # incoming tensors are batch-first; the GRU expects (seq_len, batch, ...)
        input = input.transpose(0, 1)
        hidden = hidden.transpose(0, 1)
        gru_output, h_n = self.encoder_gru(input, hidden)
        return gru_output.transpose(0, 1)  # back to (batch, seq_len, hidden_size * num_directions)

data = Variable(torch.ones(6, 2, 4))      # (batch=6, seq_len=2, input_size=4)
cuda_data = data.cuda(2)
model0 = Multi()
a = Variable(torch.ones(6, 2, 2)) * 2     # hidden: (batch, num_layers * num_directions, hidden_size)

model = nn.DataParallel(model0, device_ids=[2, 3])
model.cuda(2)
output = model(cuda_data, a.cuda(2))
loss = output.sum()
loss.backward()
print output
print data.grad
for name, param in model0.named_parameters():
    print name, param.grad

Error info:

Traceback (most recent call last):
  File "/home/ryk/programming/kbp_torch/test.py", line 223, in <module>
    output = model(cuda_data, a.cuda(2))
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 206, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/parallel/data_parallel.py", line 61, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/parallel/data_parallel.py", line 71, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/parallel/parallel_apply.py", line 46, in parallel_apply
    raise output
AssertionError

Process finished with exit code 1

Oh, I found the reason.
The input and output must be contiguous. But I don't know why; in the single-GPU case they don't need to be contiguous.
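
For anyone hitting the same error, here is a minimal sketch of the workaround under that assumption (the class name MultiContiguous is just for illustration): since transpose() only returns a non-contiguous view, call .contiguous() before handing the tensors to the GRU.

import torch
import torch.nn as nn
from torch.autograd import Variable


class MultiContiguous(nn.Module):  # hypothetical name for the fixed module
    def __init__(self):
        super(MultiContiguous, self).__init__()
        self.encoder_gru = nn.GRU(input_size=4, hidden_size=2,
                                  num_layers=1, bidirectional=True)

    def forward(self, input, hidden):
        # transpose() returns a non-contiguous view, so copy into
        # contiguous memory before the GRU sees the tensors
        input = input.transpose(0, 1).contiguous()
        hidden = hidden.transpose(0, 1).contiguous()
        gru_output, h_n = self.encoder_gru(input, hidden)
        # the output is made contiguous as well, per the note above
        return gru_output.transpose(0, 1).contiguous()

With the .contiguous() calls in place, the DataParallel example above should no longer hit the assertion.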

I encountered the same problem when I called rnn.flatten_parameters() before calling rnn(inputs). Removing the flatten_parameters() call solved the problem.
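
A minimal sketch of that pattern (the Encoder module and its shapes are hypothetical, just to show where flatten_parameters() sat relative to the GRU call):

import torch.nn as nn


class Encoder(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super(Encoder, self).__init__()
        self.rnn = nn.GRU(input_size=4, hidden_size=2, bidirectional=True)

    def forward(self, inputs, hidden):
        # Under DataParallel this call triggered the same AssertionError;
        # removing it made the forward pass work again.
        # self.rnn.flatten_parameters()
        return self.rnn(inputs, hidden)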