Multi-layer RNN with DataParallel

Here is one approach to using nn.DataParallel that I find works well. It lets your module return both the RNN output and the hidden states while using batch_first=False (a popular mode).

  1. Use batch_first = False for seq2seq modeling (e.g., an encoder-decoder architecture), which seems to be the popular way of feeding the input data, so the input batch has shape (max_sequence_length, batch_size, num_embeddings).

  2. Then set dim=1 when calling torch.nn.DataParallel to wrap the model, so the data is scattered along the batch dimension. Every tensor argument to your model has to have dimension 1 as the batch dimension. So if you also pass sequence lengths for the pack-unpack trick, which is normally a tensor of shape (batch_size,), call .unsqueeze(0) to make it of shape (1, batch_size) so it can be scattered along dimension 1 (see the usage sketch at the end of this post). If an input argument is a scalar value (or a Python float), it is OK to use it as-is, since it is simply copied to every replica.

  3. I also find that if the module’s forward function has an argument with a default value, then it won’t work.

  4. To use the pack-unpack trick, pay attention to one caveat when calling pad_packed_sequence: we need to make sure the results from all replicas have the same shape, otherwise the gathering will fail. In particular, remember to set total_length to the max sequence length in the pad_packed_sequence call, as in the sketch after this list. See https://pytorch.org/docs/master/notes/faq.html#pack-rnn-unpack-with-data-parallelism
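Putting points 1–4 together, here is a minimal sketch of what such a module can look like. The Encoder class, the GRU, and the argument names/sizes are my own illustration, not from any particular codebase:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class Encoder(nn.Module):
    # Hypothetical multi-layer GRU encoder using batch_first=False.
    def __init__(self, num_embeddings, embed_dim, hidden_dim, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embed_dim)
        self.rnn = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=False)

    # Note: no default values on the forward arguments (point 3).
    def forward(self, padded_input, lengths, total_length):
        # padded_input: (max_sequence_length, batch_size), scattered along dim 1
        # lengths: (1, batch_size), scattered along dim 1; undo the unsqueeze here
        lengths = lengths.squeeze(0)
        embedded = self.embedding(padded_input)              # (seq, batch, embed_dim)
        packed = pack_padded_sequence(embedded, lengths.cpu(),
                                      batch_first=False, enforce_sorted=False)
        packed_output, hidden = self.rnn(packed)
        # total_length makes every replica pad its output to the same sequence
        # length, so DataParallel can gather the chunks back along dim 1 (point 4).
        output, _ = pad_packed_sequence(packed_output, batch_first=False,
                                        total_length=total_length)
        # output: (total_length, batch, hidden_dim); hidden: (num_layers, batch, hidden_dim)
        # Both have the batch on dim 1, so gathering with dim=1 works for both.
        return output, hidden
```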

Note that the example in the link uses batch_first=True, since it does not return hidden states. If you are using batch_first=False, get the max sequence length from dimension 0 of padded_input (padded_input.shape[0]).
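Here is a rough usage sketch, assuming a multi-GPU machine and the hypothetical Encoder above (the sizes are made up), showing the dim=1 wrapping, the unsqueezed lengths, and total_length taken from padded_input.shape[0]:

```python
# Wrap the model so inputs are scattered and outputs gathered along dim 1.
model = nn.DataParallel(
    Encoder(num_embeddings=1000, embed_dim=32, hidden_dim=64, num_layers=2).cuda(),
    dim=1)

max_len, batch_size = 20, 8
padded_input = torch.randint(0, 1000, (max_len, batch_size)).cuda()  # (seq, batch)
lengths = torch.randint(1, max_len + 1, (batch_size,))               # (batch,)

total_length = padded_input.shape[0]   # max sequence length from dim 0 (batch_first=False)
output, hidden = model(padded_input,
                       lengths.unsqueeze(0).cuda(),  # (1, batch) so it scatters on dim 1
                       total_length)                 # plain int, copied to every replica
# output: (max_len, batch_size, 64); hidden: (num_layers, batch_size, 64)
```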