Multi-layer RNN with DataParallel

I found that when using DataParallel with a multi-layer RNN module, c0 and h0 get split as well, so a size error is raised.


Hello @Varg_Nord!
In my experiments, DataParallel works correctly with a multi-layer RNN. An example is provided below:

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence as pack, pad_packed_sequence as unpack

import constants  # the poster's own module providing the padding index constants.PAD


class RNNLM(nn.Module):

    def __init__(self, ntokens=100000, nx=300, nhid=600,
                 nlayers=3, dropout=0.5):
        super(RNNLM, self).__init__()

        self.keep_prob = dropout
        self.ntokens = ntokens
        self.nx = nx
        self.nhid = nhid
        self.nlayers = nlayers

        self.dropout = nn.Dropout(p=self.keep_prob)

        self.embs = nn.Embedding(num_embeddings=self.ntokens,
                                 embedding_dim=self.nx,
                                 padding_idx=constants.PAD)

        self.rnn = nn.LSTM(input_size=self.nx,
                           hidden_size=self.nhid,
                           num_layers=self.nlayers,
                           batch_first=True,
                           dropout=self.keep_prob)

        self.linear = nn.Linear(in_features=self.nhid,
                                out_features=self.ntokens)

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embs.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.fill_(0)
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, x, maxlen, conds, hidden=None):

        # actual (non-padded) length of each sequence in the batch
        lengths = x.ne(constants.PAD).sum(dim=1).data.cpu().view(-1).numpy()

        embs = self.embs(x)
        embs = pack(embs, lengths, batch_first=True)

        output, hidden = self.rnn(embs, hidden)
        output = unpack(output, batch_first=True)[0]
        output = self.dropout(output)

        # pad every replica's output up to the global maxlen so DataParallel can gather them
        padded_output = Variable(
            torch.zeros(output.size(0), maxlen, output.size(2))
        ).cuda()

        padded_output[:, :max(lengths), :] = output

        decoded = self.linear(
            padded_output.view(
                padded_output.size(0) * padded_output.size(1),
                padded_output.size(2)
            )
        )

        return decoded, hidden

net = torch.nn.DataParallel(RNNLM(ntokens=nvocab, nx=300, nhid=300, nlayers=3, dropout=0.5).cuda())

Also, don't forget to set the CUDA_VISIBLE_DEVICES environment variable.
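For example (one possible way, assuming you want GPUs 0 and 1), you can set it in the shell as CUDA_VISIBLE_DEVICES=0,1 python your_script.py, or at the very top of the script before the first CUDA call:

import os

# must happen before CUDA is initialised (i.e. before any .cuda() call)
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"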

If you still get an error with this code, then please provide the error and your setup.

Here is my minimal example:

import torch.nn as nn
import torch
import torch.nn.init as weight_init

class Network(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Network, self).__init__()
        self.lstm = nn.LSTM(
            input_size=input_size, hidden_size=hidden_size, num_layers=2, batch_first=True)
        for p in self.lstm.parameters():
            weight_init.normal(p, std=0.1)

    def forward(self, input_var, h0):
        output, ht = self.lstm(input_var, h0)
        return output, ht

net = Network(256, 256)
net.cuda()
dp = torch.nn.DataParallel(net)
input_var = torch.autograd.Variable(torch.rand(1, 32, 256).cuda())  # (batch, seq_len, input_size)
h0 = torch.autograd.Variable(torch.randn(2, 1, 256).cuda())         # (num_layers, batch, hidden_size)
c0 = torch.autograd.Variable(torch.randn(2, 1, 256).cuda())
h = (h0, c0)

out, ht = dp(input_var, h)

The error is RuntimeError: Expected hidden size (2, 1L, 256), got (1L, 1L, 256L). I believe this is because PyTorch cuts the input tensor into pieces and sends them to different GPUs, but I think the hidden state should not be split. As for your example, I will simplify it and try it.

@Varg_Nord How many GPUs do you use?

I have only two GPUs.

@Varg_Nord I found the problem. If batch_first=True is used, then DataParallel with the default parameter dim=0 will split both input_var and h0 along the first dimension. That is correct for input_var, but not for h0, because RNN hidden states always have the shape num_layers * num_directions x batch_size x hidden_size. As an easy solution, you can use batch_first=False, so that the second dimension corresponds to the batch dimension of input_var, and use DataParallel with dim=1, so it splits both input_var and h0 along the correct dimension.
If you want to use batch_first=True, then you can swap the axes of the hidden states before .forward and swap them back inside .forward, but this is computationally inefficient.
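For reference, here is a minimal sketch of that easy solution applied to the small example above (my own adaptation, written against a recent PyTorch, with batch size 8 so the split across two GPUs is visible):

import torch
import torch.nn as nn

class Network(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Network, self).__init__()
        # batch_first=False: input is (seq_len, batch, input_size)
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=2, batch_first=False)

    def forward(self, input_var, h0):
        output, ht = self.lstm(input_var, h0)
        return output, ht

net = Network(256, 256).cuda()
dp = torch.nn.DataParallel(net, dim=1)       # dim=1 is the batch dimension for input and hidden

input_var = torch.rand(32, 8, 256).cuda()    # (seq_len, batch, input_size)
h0 = torch.randn(2, 8, 256).cuda()           # (num_layers, batch, hidden_size)
c0 = torch.randn(2, 8, 256).cuda()

out, (hn, cn) = dp(input_var, (h0, c0))      # no size error: everything is split along dim 1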


Thank you so much! I tried your example and was really confused by the difference. Thank you for pointing out what I missed :)

@apaszke @smth Hello!
As we found in this topic, DataParallel does not work correctly with RNN hidden states when batch_first=True, because the batch_first option affects only the input and output of the RNN, but not the hidden states. Maybe it would be better to add the following feature: if batch_first=True, then the input and output hidden states of the RNN have shape (batch, num_layers * num_directions, hidden_size) instead of (num_layers * num_directions, batch, hidden_size), which is currently fixed for both values of batch_first. That would make it easier to work with hidden states without boilerplate permutations every time, and it would be consistent with the input and output dimensions.

Another way is to extend DataParallel so that a different split dimension can be used for each network input; currently it uses the same split dimension for every input.
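As a rough, hedged illustration of what that could look like (my own sketch built from the public data-parallel primitives, not an existing PyTorch API), assuming module already lives on device_ids[0], takes (input_var, (h0, c0)) with batch_first=True, and returns (output, (h_n, c_n)):

import torch
from torch.nn.parallel import replicate, scatter, parallel_apply, gather

def rnn_data_parallel(module, input_var, hx, device_ids):
    # input_var: (batch, seq, features)        -> split along dim 0
    # hx = (h0, c0), each (layers, batch, hid) -> split along dim 1
    input_chunks = scatter(input_var, device_ids, dim=0)
    hidden_chunks = scatter(hx, device_ids, dim=1)   # scatter recurses into the tuple
    used_devices = device_ids[:len(input_chunks)]
    replicas = replicate(module, used_devices)
    outputs = parallel_apply(replicas, list(zip(input_chunks, hidden_chunks)))
    # gather each result back on device_ids[0], again with a per-tensor dimension
    out = gather([o[0] for o in outputs], device_ids[0], dim=0)
    h_n = gather([o[1][0] for o in outputs], device_ids[0], dim=1)
    c_n = gather([o[1][1] for o in outputs], device_ids[0], dim=1)
    return out, (h_n, c_n)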


@Varg_Nord I've tested the approach with hidden-state permutations. For me it is faster than using batch_first=False with the necessary permutations in .forward. A correct RNNLM with DataParallel now looks as follows:

# imports and the `constants` module are the same as in the earlier example
class Model(nn.Module):
    def __init__(self, ntokens=100000, nx=300, nhid=600,
                 nlayers=3, dropout=0.5):
        super(Model, self).__init__()

        self.keep_prob = dropout
        self.ntokens = ntokens
        self.nx = nx
        self.nhid = nhid
        self.nlayers = nlayers

        self.dropout = nn.Dropout(p=self.keep_prob)

        self.embs = nn.Embedding(num_embeddings=self.ntokens,
                                 embedding_dim=self.nx,
                                 padding_idx=constants.PAD)

        self.rnn = nn.LSTM(input_size=self.nx,
                           hidden_size=self.nhid,
                           num_layers=self.nlayers,
                           batch_first=True,
                           dropout=self.keep_prob)

        self.linear = nn.Linear(in_features=self.nhid,
                                out_features=self.ntokens)

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embs.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.fill_(0)
        self.linear.weight.data.uniform_(-initrange, initrange)

    def init_hidden(self, batch_size=96):
        # batch dimension first, so DataParallel (dim=0) splits the hidden states correctly
        return [
            Variable(torch.zeros(batch_size, self.nlayers, self.nhid)).cuda(),
            Variable(torch.zeros(batch_size, self.nlayers, self.nhid)).cuda(),
        ]

    def forward(self, x, maxlen, conds, hidden):

        # (batch, nlayers, nhid) -> (nlayers, batch, nhid), as nn.LSTM expects
        for i in range(len(hidden)):
            hidden[i] = hidden[i].permute(1, 0, 2).contiguous()

        lengths = x.ne(constants.PAD).sum(dim=1).data.cpu().view(-1).numpy()

        embs = self.embs(x)
        embs = pack(embs, lengths, batch_first=True)

        output, hidden = self.rnn(embs, tuple(hidden))  # nn.LSTM expects a (h_0, c_0) tuple
        output = unpack(output, batch_first=True)[0]
        output = self.dropout(output)

        padded_output = Variable(
            torch.zeros(output.size()[0], maxlen, output.size()[2])
        ).cuda()

        padded_output[:, :max(lengths), :] = output

        decoded = self.linear(
            padded_output.view(
                padded_output.size(0) * padded_output.size(1),
                padded_output.size(2)
            )
        )

        # (nlayers, batch, nhid) -> (batch, nlayers, nhid) so DataParallel gathers along the batch
        hidden = list(hidden)
        for i in range(len(hidden)):
            hidden[i] = hidden[i].permute(1, 0, 2).contiguous()

        return decoded, hidden

As you can see, I initialise the hidden states with the batch dimension first, pass them to .forward, and there permute them to the layout torch.nn.LSTM expects, where the batch dimension is second. Then I permute them back so that DataParallel gathers them correctly. It works fast and correctly.

net = torch.nn.DataParallel(Model(ntokens, nx, nhid, nlayers, dropout).cuda(), dim=0)
hidden = net.module.init_hidden(batch_size)
output, hidden = net(input, hidden)

And if you check the dimensions of the hidden states and inputs inside .forward, they will be correct.


Here is one solution for using nn.DataParallel that I find works well. It can return both the RNN output and the hidden states from your module, using batch_first=False mode (which is a popular mode).

  1. Use batch_first=False for seq2seq modeling (e.g., an encoder-decoder architecture), which seems to be the popular way of feeding the input data, so we get an input batch of shape (max_sequence_length, batch_size, num_embeddings).

  2. Then set dim=1 to scatter the data along the batch dimension when calling torch.nn.DataParallel to wrap the model. Every input argument to your model has to have its batch dimension at dimension 1. So if you also pass sequence lengths for the pack-unpack trick, which is normally a tensor of shape (batch_size,), you need to call .unsqueeze(0) to make it of shape (1, batch_size) so it can be scattered along dimension 1. If an input argument is a scalar value (or Python float), it is OK to use it as it is, since it is broadcastable along dimension 1.

  3. I also find that if the module's forward function has an argument with a default value, then it won't work.

  4. To use the pack-unpack trick, pay attention to one caveat when calling pad_packed_sequence: we need to make sure all replicas return results of the same shape, otherwise the gather will fail. In particular, remember to set total_length to the max sequence length in the pad_packed_sequence call. See https://pytorch.org/docs/master/notes/faq.html#pack-rnn-unpack-with-data-parallelism

Note that the example in the link uses batch_first=True, since it does not return hidden states. If you are using batch_first=False, get the max sequence length from dimension 0 of padded_input (padded_input.shape[0]).
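Putting points 1, 2 and 4 together, here is a rough sketch of the recipe (my own, not taken from the linked FAQ; MyRNN and the sizes are placeholders, and it assumes a reasonably recent PyTorch where enforce_sorted is available):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class MyRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(MyRNN, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=False)

    def forward(self, padded_input, lengths):
        # padded_input: (max_seq_len, batch, features); lengths: (1, batch)
        total_length = padded_input.size(0)   # dim 0 is not split, so this is the global max length
        lengths = lengths.squeeze(0).cpu()    # back to (batch,); pack wants CPU lengths
        packed = pack_padded_sequence(padded_input, lengths, enforce_sorted=False)
        packed_out, (h_n, c_n) = self.lstm(packed)
        # total_length keeps every replica's output the same size, so the gather works
        output, _ = pad_packed_sequence(packed_out, total_length=total_length)
        return output, (h_n, c_n)

net = torch.nn.DataParallel(MyRNN(128, 256).cuda(), dim=1)

x = torch.rand(50, 16, 128).cuda()                  # (max_seq_len, batch, features)
lengths = torch.randint(1, 51, (16,)).unsqueeze(0)  # (1, batch): the batch sits on dim 1
out, (h_n, c_n) = net(x, lengths)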

You may find this gist helpful:
How to use PyTorch DataParallel to train LSTM on characters

PyTorch 1.5 has completely fixed the issues with RNN training and DataParallel.
It seems to have done so quite seamlessly; no more gerrymandering is required.
I confirmed this today in a project involving bidirectional GRUs on speech MFCCs.