Multi-layer RNN with DataParallel

@Varg_Nord I've tested the approach with hidden-state permutations. For me it is faster than using batch_first=False and doing the necessary permutations in .forward. A correct RNNLM with DataParallel now looks like the following:

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack

import constants  # project-specific module that defines the PAD index


class Model(nn.Module):
    def __init__(self, ntokens=100000, nx=300, nhid=600,
                 nlayers=3, dropout=0.5):
        super(Model, self).__init__()

        self.keep_prob = dropout
        self.ntokens = ntokens
        self.nx = nx
        self.nhid = nhid
        self.nlayers = nlayers

        self.dropout = nn.Dropout(p=self.keep_prob)

        self.embs = nn.Embedding(num_embeddings=self.ntokens,
                                 embedding_dim=self.nx,
                                 padding_idx=constants.PAD)

        self.rnn = nn.LSTM(input_size=self.nx,
                           hidden_size=self.nhid,
                           num_layers=self.nlayers,
                           batch_first=True,
                           dropout=self.keep_prob)

        self.linear = nn.Linear(in_features=self.nhid,
                                out_features=self.ntokens)

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embs.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.fill_(0)
        self.linear.weight.data.uniform_(-initrange, initrange)

    def init_hidden(self, batch_size=96):
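        # keep the hidden state batch-first, (batch, nlayers, nhid), so that
        # DataParallel can scatter it along dim 0 together with the input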
        return [
            Variable(torch.zeros(batch_size, self.nlayers, self.nhid)).cuda(),
            Variable(torch.zeros(batch_size, self.nlayers, self.nhid)).cuda(),
        ]

    def forward(self, x, maxlen, hidden):

        # permute the incoming hidden states from (batch, nlayers, nhid) to the
        # (nlayers, batch, nhid) layout that torch.nn.LSTM expects
        for i in range(len(hidden)):
            hidden[i] = hidden[i].permute(1, 0, 2).contiguous()

        # effective sequence lengths (number of non-PAD tokens per row);
        # pack_padded_sequence expects the batch sorted by decreasing length
        lengths = x.ne(constants.PAD).sum(dim=1).data.cpu().view(-1).numpy()

        embs = self.embs(x)
        embs = pack(embs, lengths, batch_first=True)

        # nn.LSTM expects the hidden state as an (h_0, c_0) tuple
        output, hidden = self.rnn(embs, tuple(hidden))
        output = unpack(output, batch_first=True)[0]
        output = self.dropout(output)

        # pad the per-replica output along the time axis up to the global maxlen,
        # so that DataParallel can gather outputs of equal size from every GPU
        padded_output = Variable(
            torch.zeros(output.size()[0], maxlen, output.size()[2])
        ).cuda()

        padded_output[:, :max(lengths), :] = output

        decoded = self.linear(
            padded_output.view(
                padded_output.size(0) * padded_output.size(1),
                padded_output.size(2)
            )
        )

        # permute the hidden states back to batch-first so that DataParallel
        # can gather them along the batch dimension
        hidden = list(hidden)
        for i in range(len(hidden)):
            hidden[i] = hidden[i].permute(1, 0, 2).contiguous()

        return decoded, hidden

As you can see, I initialise the hidden states with the batch dimension first, pass them into .forward, and there permute them to the layout torch.nn.LSTM expects, with the batch dimension second. Before returning, I permute them back so that DataParallel gathers them correctly along the batch dimension. This works fast and correctly.
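To see why the batch-first layout matters, here is a minimal CPU-only sketch (the sizes are illustrative, not from the model above) that mimics DataParallel splitting every tensor argument along dim=0 by using torch.chunk:

import torch

nlayers, batch, nhid = 3, 8, 600

# layout nn.LSTM uses, (nlayers, batch, nhid): splitting along dim 0 would
# give each replica a subset of *layers*, not a subset of the batch
h_lstm = torch.zeros(nlayers, batch, nhid)
print([t.size() for t in torch.chunk(h_lstm, 2, dim=0)])
# -> (2, 8, 600) and (1, 8, 600), which is wrong

# batch-first layout, (batch, nlayers, nhid): dim 0 is the batch, so each
# replica gets its own half of the batch and permutes it inside .forward
h_batch_first = torch.zeros(batch, nlayers, nhid)
print([t.size() for t in torch.chunk(h_batch_first, 2, dim=0)])
# -> (4, 3, 600) and (4, 3, 600), which is what we want

With that in place, the model is wrapped and called like this: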

net = torch.nn.DataParallel(Model(ntokens, nx, nhid, nlayers, dropout).cuda(), dim=0)
hidden = net.module.init_hidden(batch_size)
output, hidden = net(input, maxlen, hidden)

If you check the dimensions of the hidden states and the inputs inside .forward, they will be correct for each replica.
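For example, here is a quick CPU-only sanity check of the permute round trip done inside .forward (assuming batch_size=96 split across 2 GPUs, so each replica sees 48 rows; these numbers are illustrative):

import torch

nlayers, nhid, per_gpu_batch = 3, 600, 48  # 96 split over 2 GPUs (assumed)

h = torch.zeros(per_gpu_batch, nlayers, nhid)   # as scattered by DataParallel
h_lstm = h.permute(1, 0, 2).contiguous()        # layout handed to nn.LSTM
assert h_lstm.size() == (nlayers, per_gpu_batch, nhid)

h_back = h_lstm.permute(1, 0, 2).contiguous()   # layout returned for the gather
assert h_back.size() == (per_gpu_batch, nlayers, nhid)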
