Does an RNN work with the DataParallel wrapper?

If I wrap an RNN with DataParallel, the output size doesn't seem to match the target size. For instance, with a batch size of 32 and 2 active GPUs, each GPU processes 16 instances. Those outputs should then be gathered back into the full batch of 32 instances for the loss function. With an RNN, however, the gathering doesn't happen: the model outputs only 16 instances, which conflicts with the size of the target.

I don't know if that makes sense?
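To make the expected behavior concrete, here is a minimal CPU-only sketch of what nn.DataParallel does with a batch: it splits the input along dim 0 across replicas, runs each chunk, and concatenates the outputs along dim 0 again. The helper simulate_data_parallel and the LastStep module are hypothetical illustrations (torch.chunk and torch.cat stand in for the real scatter/gather), and the sizes are made up.

```python
import torch
import torch.nn as nn


def simulate_data_parallel(module, x, n_replicas=2, dim=0):
    """Hypothetical CPU stand-in for nn.DataParallel's scatter/gather.

    Splits the input along `dim`, runs every chunk through the same module,
    and concatenates the returned tensors along `dim` again.
    """
    chunks = torch.chunk(x, n_replicas, dim=dim)   # "scatter": 32 -> 16 + 16
    outputs = [module(chunk) for chunk in chunks]  # one call per "GPU"
    return torch.cat(outputs, dim=dim)             # "gather": 16 + 16 -> 32


class LastStep(nn.Module):
    """Batch-first LSTM that returns only the last time step's output."""

    def __init__(self):
        super().__init__()
        self.rnn = nn.LSTM(input_size=8, hidden_size=4, batch_first=True)

    def forward(self, x):
        out, _ = self.rnn(x)   # out: (batch, seq, hidden)
        return out[:, -1, :]   # (batch, hidden)


torch.manual_seed(0)
model = LastStep()
x = torch.randn(32, 10, 8)  # batch of 32, 10 time steps, 8 features
y = simulate_data_parallel(model, x, n_replicas=2)
print(y.shape)  # torch.Size([32, 4])
```

When the returned tensor has the batch on dim 0, the gathered result contains all 32 instances, which is what the loss function expects.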

You can look at to see how to wrap an RNN in DataParallel.

You have to check whether your RNN uses batch_first, and also which dimension DataParallel scatters along (dim 0 by default).
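For reference, this sketch (sizes are arbitrary assumptions) shows the shapes nn.LSTM returns. batch_first=True affects only the input and output tensors; the hidden and cell states keep the batch on dim 1. Since nn.DataParallel scatters inputs and gathers outputs along its dim argument (0 by default), an input laid out with batch_first=False would be split along the time axis rather than the batch.

```python
import torch
import torch.nn as nn

batch, seq_len, n_features, hidden = 32, 10, 8, 4

# batch_first=True: input and output are (batch, seq, feature)
rnn_bf = nn.LSTM(n_features, hidden, num_layers=2, batch_first=True)
out_bf, (h_bf, _) = rnn_bf(torch.randn(batch, seq_len, n_features))
print(out_bf.shape)  # torch.Size([32, 10, 4]): batch is dim 0
print(h_bf.shape)    # torch.Size([2, 32, 4]): batch is dim 1 even with batch_first=True

# batch_first=False (the default): input and output are (seq, batch, feature)
rnn_sf = nn.LSTM(n_features, hidden, num_layers=2)
out_sf, (h_sf, _) = rnn_sf(torch.randn(seq_len, batch, n_features))
print(out_sf.shape)  # torch.Size([10, 32, 4]): batch is dim 1
print(h_sf.shape)    # torch.Size([2, 32, 4])
```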

This is strange. If a custom RNN module returns both the RNN state and the output, DataParallel does not work and the problem above appears. If you return only the RNN output, things are fine.
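One likely explanation (not confirmed in the thread): nn.GRU and nn.LSTM return the hidden state h_n with shape (num_layers * num_directions, batch, hidden) even when batch_first=True, while DataParallel concatenates every returned tensor, including each element of a returned tuple, along dim 0. The per-GPU states are then stacked along the layer axis, so the gathered state still holds only 16 instances. The sketch below reproduces this on CPU with the same kind of hypothetical chunk/cat stand-in, and shows one possible workaround: transpose h_n so the batch comes first before returning it, then transpose it back after the gather. NaiveRNN, BatchFirstStateRNN and all sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn


def simulate_data_parallel(module, x, n_replicas=2, dim=0):
    """Hypothetical CPU stand-in for DataParallel's scatter/gather with tuple outputs."""
    outputs = [module(chunk) for chunk in torch.chunk(x, n_replicas, dim=dim)]
    # Gather each tuple element separately, always along `dim`.
    return tuple(torch.cat(parts, dim=dim) for parts in zip(*outputs))


class NaiveRNN(nn.Module):
    """Returns (output, h_n) exactly as nn.GRU produces them."""

    def __init__(self, n_features=8, hidden=4, num_layers=2):
        super().__init__()
        self.rnn = nn.GRU(n_features, hidden, num_layers=num_layers, batch_first=True)

    def forward(self, x):
        # out: (batch, seq, hidden); h_n: (num_layers, batch, hidden)
        return self.rnn(x)


class BatchFirstStateRNN(NaiveRNN):
    """Hypothetical workaround: move the batch to dim 0 of h_n before returning it."""

    def forward(self, x):
        out, h_n = self.rnn(x)
        return out, h_n.transpose(0, 1).contiguous()  # h_n: (batch, num_layers, hidden)


torch.manual_seed(0)
x = torch.randn(32, 10, 8)  # batch of 32, 10 time steps, 8 features

naive_out, naive_h = simulate_data_parallel(NaiveRNN(), x)
print(naive_out.shape)  # torch.Size([32, 10, 4])
print(naive_h.shape)    # torch.Size([4, 16, 4]): layers stacked, only 16 instances

fixed_out, fixed_h = simulate_data_parallel(BatchFirstStateRNN(), x)
fixed_h = fixed_h.transpose(0, 1)  # back to (num_layers, batch, hidden)
print(fixed_out.shape)  # torch.Size([32, 10, 4])
print(fixed_h.shape)    # torch.Size([2, 32, 4])
```

Returning only the output avoids the issue because the output already has the batch on dim 0 when batch_first=True, which matches the observation above.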