How to use nn.parallel.data_parallel with an LSTM

I tried using data_parallel with an LSTM:

import torch
import torch.nn as nn

input = torch.randn(50, 99, 100)  # (batch_size, seq_length, embed_size)
h0 = torch.randn(4, 50, 500)      # (num_layers * num_directions, batch_size, hidden_size)
c0 = torch.randn(4, 50, 500)      # (num_layers * num_directions, batch_size, hidden_size)

encoder = nn.LSTM(100, 500, 2, batch_first=True, bidirectional=True).cuda()
output, (h_t, c_t) = nn.parallel.data_parallel(encoder, (input, (h0, c0)), device_ids=[0, 1])

The hidden state dimensions are (num_layers * num_directions, batch_size, hidden_size), unlike the input, which is (batch_size, seq_length, embed_size). For data_parallel, the first dimension of every input needs to be batch_size so it can be scattered across devices, which is not the case for h0 and c0, hence the error I get. How can I solve this issue?
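One workaround I've been considering (I'm not sure it's the idiomatic way) is to wrap the LSTM in a small module, here called BatchFirstLSTM (my own name), that accepts the hidden states batch-first and transposes them back inside forward, so every tensor handed to data_parallel can be scattered along dim 0. A rough sketch:

import torch
import torch.nn as nn

class BatchFirstLSTM(nn.Module):
    """Thin wrapper so h0/c0 can be passed with batch on dim 0."""
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.lstm = nn.LSTM(*args, **kwargs)

    def forward(self, x, h0, c0):
        # After scattering, h0/c0 arrive as (batch, num_layers * num_directions, hidden);
        # transpose them back to the layout nn.LSTM expects.
        h0 = h0.transpose(0, 1).contiguous()
        c0 = c0.transpose(0, 1).contiguous()
        output, (h_t, c_t) = self.lstm(x, (h0, c0))
        # Return the hidden states batch-first as well so gather concatenates on dim 0.
        return output, h_t.transpose(0, 1), c_t.transpose(0, 1)

# input, h0, c0 as above
encoder = BatchFirstLSTM(100, 500, 2, batch_first=True, bidirectional=True).cuda()
output, h_t, c_t = nn.parallel.data_parallel(
    encoder,
    (input, h0.transpose(0, 1), c0.transpose(0, 1)),
    device_ids=[0, 1],
)
# h_t, c_t come back as (batch_size, num_layers * num_directions, hidden_size)

Would something like this work, or is there a built-in way to tell data_parallel to split h0 and c0 along dim 1?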

Thanks!