Data parallel with RNN

I developed a multilayer RNN with DataParallel, but I have an issue with the size of the hidden state.

import torch
import torch.nn as nn

class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, n_layers=1, dropout=0.1):
        super(EncoderRNN, self).__init__()
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.dropout = dropout
        
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=self.dropout, bidirectional=True, batch_first=False)
        
    def forward(self, input_seqs, input_lengths, hidden=None):
        # Note: we run this all at once (over multiple batches of multiple sequences)
        # input_lengths is currently unused (no packing is done here)
        embedded = self.embedding(input_seqs)

        self.gru.flatten_parameters()
        outputs, hidden = self.gru(embedded, hidden)
                
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:] # Sum bidirectional outputs
        return outputs, hidden
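
For reference, this is what I expect from the module on its own: with a bidirectional GRU, hidden should come back as (n_layers * num_directions, batch_size, hidden_size). A quick CPU check (the vocabulary size of 100 and the dummy input are just placeholders):

encoder = EncoderRNN(100, 8, n_layers=2)        # 100 is a placeholder vocab size
dummy_input = torch.randint(0, 100, (54, 3))    # (max_len, batch_size)
outputs, hidden = encoder(dummy_input, None)

print(outputs.size())   # torch.Size([54, 3, 8]) -> (max_len, batch_size, hidden_size)
print(hidden.size())    # torch.Size([4, 3, 8])  -> (n_layers * 2, batch_size, hidden_size)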

This is my test code:

small_batch_size = 3
input_batches, input_lengths, target_batches, target_lengths = random_batch(small_batch_size)

print('input_batches', input_batches.size()) # (max_len x batch_size) -> input_batches torch.Size([54, 3])
print('target_batches', target_batches.size()) # (max_len x batch_size) -> target_batches torch.Size([54, 3])

small_hidden_size = 8
small_n_layers = 2

encoder_test = EncoderRNN(sentence.n_words, small_hidden_size, small_n_layers)

encoder_test = torch.nn.DataParallel(encoder_test.cuda(), device_ids=[0,1])

This is the confusing part.

encoder_outputs, encoder_hidden = encoder_test(input_batches, input_lengths, None)

print('encoder_outputs', encoder_outputs.size()) # max_len x batch_size x hidden_size
print('encoder_hidden', encoder_hidden.size()) # n_layers * 2 x batch_size x hidden_size

When I print the size of the encoder hidden state I get encoder_hidden torch.Size([8, 3, 8]), but the correct size should be [4, 3, 8] -> [2 (n_layers) * 2 (directions), 3 = batch_size, 8 = hidden_size]. If I use CUDA without DataParallel I get the correct size.
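
My guess (and I might be wrong) is that DataParallel scatters inputs and gathers outputs along dim 0 by default, and since my batches are (max_len x batch_size) with batch_first=False, dim 0 is the sequence dimension here, not the batch dimension. So the two per-GPU hidden states of size (4, 3, 8) would get concatenated along the layer dimension. A rough CPU-only sketch of what I think is happening (simulating the dim-0 split/gather by hand, no DataParallel involved, placeholder vocab size of 100):

encoder = EncoderRNN(100, 8, n_layers=2)
input_batches = torch.randint(0, 100, (54, 3))   # (max_len, batch_size)

# Simulate the default dim=0 scatter: each "GPU" gets half of dim 0,
# which is the sequence dimension here, not the batch dimension.
chunks = torch.chunk(input_batches, 2, dim=0)    # two (27, 3) chunks

hiddens = []
for chunk in chunks:
    _, hidden = encoder(chunk, None)             # hidden: (4, 3, 8) per "GPU"
    hiddens.append(hidden)

# Simulate the default dim=0 gather: the two (4, 3, 8) hiddens are
# concatenated along the layer dimension, giving (8, 3, 8).
gathered_hidden = torch.cat(hiddens, dim=0)
print(gathered_hidden.size())                    # torch.Size([8, 3, 8])

Is that what is going on, and is there a recommended way to get the hidden state back as [4, 3, 8] when using DataParallel?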