@Varg_Nord I’ve tested the approach with hidden state permutations. For me it is faster than using batch_first=False and doing the necessary permutations in .forward. A correct RNNLM with DataParallel now looks as follows:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack

import constants  # project-specific module that defines the PAD index


class Model(nn.Module):
    def __init__(self, ntokens=100000, nx=300, nhid=600,
                 nlayers=3, dropout=0.5):
        super(Model, self).__init__()
        self.keep_prob = dropout  # dropout probability
        self.ntokens = ntokens
        self.nx = nx
        self.nhid = nhid
        self.nlayers = nlayers
        self.dropout = nn.Dropout(p=self.keep_prob)
        self.embs = nn.Embedding(num_embeddings=self.ntokens,
                                 embedding_dim=self.nx,
                                 padding_idx=constants.PAD)
        self.rnn = nn.LSTM(input_size=self.nx,
                           hidden_size=self.nhid,
                           num_layers=self.nlayers,
                           batch_first=True,
                           dropout=self.keep_prob)
        self.linear = nn.Linear(in_features=self.nhid,
                                out_features=self.ntokens)
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embs.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.fill_(0)
        self.linear.weight.data.uniform_(-initrange, initrange)

    def init_hidden(self, batch_size=96):
        # Batch dimension first, so DataParallel can scatter the hidden
        # states along dim=0 together with the inputs.
        return [
            Variable(torch.zeros(batch_size, self.nlayers, self.nhid)).cuda(),
            Variable(torch.zeros(batch_size, self.nlayers, self.nhid)).cuda(),
        ]

    def forward(self, x, maxlen, conds, hidden):
        # Permute hidden states to (nlayers, batch, nhid), the layout
        # nn.LSTM expects even with batch_first=True.
        for i in range(len(hidden)):
            hidden[i] = hidden[i].permute(1, 0, 2).contiguous()
        lengths = x.ne(constants.PAD).sum(dim=1).data.cpu().view(-1).numpy()
        embs = self.embs(x)
        embs = pack(embs, lengths, batch_first=True)
        # nn.LSTM takes the hidden state as an (h_0, c_0) tuple.
        output, hidden = self.rnn(embs, tuple(hidden))
        output = unpack(output, batch_first=True)[0]
        output = self.dropout(output)
        # Pad the output back to maxlen so every replica returns tensors
        # of the same size for the DataParallel gather.
        padded_output = Variable(
            torch.zeros(output.size(0), maxlen, output.size(2))
        ).cuda()
        padded_output[:, :max(lengths), :] = output
        decoded = self.linear(
            padded_output.view(
                padded_output.size(0) * padded_output.size(1),
                padded_output.size(2)
            )
        )
        # Permute hidden states back to (batch, nlayers, nhid) so that
        # DataParallel gathers them along the batch dimension.
        hidden = list(hidden)
        for i in range(len(hidden)):
            hidden[i] = hidden[i].permute(1, 0, 2).contiguous()
        return decoded, hidden
As you can see, I initialise the hidden states with the batch dimension first, pass them to .forward, and permute them there to the layout torch.nn.LSTM expects, where the batch dimension comes second. Then I permute them back so that the DataParallel gather works correctly. With this it runs fast and gives correct results.
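To make the shape handling concrete, here is a minimal sketch of the permute round trip (the sizes 96, 3 and 600 are just the defaults from above):

import torch

batch_size, nlayers, nhid = 96, 3, 600

h = torch.zeros(batch_size, nlayers, nhid)   # batch-first, as in init_hidden
print(h.size())                              # torch.Size([96, 3, 600])

h = h.permute(1, 0, 2).contiguous()          # layout nn.LSTM expects: (nlayers, batch, nhid)
print(h.size())                              # torch.Size([3, 96, 600])

h = h.permute(1, 0, 2).contiguous()          # back to batch-first for the DataParallel gather
print(h.size())                              # torch.Size([96, 3, 600])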
net = torch.nn.DataParallel(Model(ntokens, nx, nhid, nlayers, dropout).cuda(), dim=0)
hidden = net.module.init_hidden(batch_size)
output, hidden = net(input, maxlen, conds, hidden)
And if you check the dimensions of the hidden states and the inputs inside .forward, they will be correct.
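For example, assuming 2 GPUs and batch_size=96 (both just illustrative here), a couple of temporary debug prints at the top of .forward should show each replica receiving half of the batch:

# Hypothetical debug prints at the top of .forward, assuming 2 GPUs and batch_size=96.
def forward(self, x, maxlen, conds, hidden):
    print(x.size())          # e.g. torch.Size([48, seq_len]) on each replica
    print(hidden[0].size())  # e.g. torch.Size([48, 3, 600]): batch dim split by DataParallel
    ...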