DataParallel LSTM/GRU wrong hidden batch size (8 GPUs)

I’m using batch_first=True, and my forward takes just the one input tensor, but I’m still hitting the error below when I wrap the model in nn.DataParallel. Here’s my model:

import torch
import torch.nn as nn
import torch.nn.functional as F


class CryptoLSTM(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, batch_size, vocab_size):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.hidden2text = nn.Linear(hidden_dim, vocab_size)
        self.hidden = self.init_hidden()

    def init_hidden(self):
        # (h_0, c_0), each of shape (num_layers, batch, hidden_dim)
        self.hidden = (
            torch.autograd.Variable(torch.zeros(1, self.batch_size, self.hidden_dim).cuda()),
            torch.autograd.Variable(torch.zeros(1, self.batch_size, self.hidden_dim).cuda()))
        return self.hidden

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        lstm_out, self.hidden = self.lstm(embeds, self.hidden)
        tag_space = self.hidden2text(lstm_out)
        scores = F.log_softmax(tag_space, dim=2)
        return scores
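
As far as I understand it, DataParallel scatters the input along dim 0, so each replica’s forward only sees a slice of the batch, while self.hidden is still built for the full self.batch_size. A rough standalone illustration of what I think is happening (the tensor sizes and chunk count are just for the example, not taken from my training code):

    import torch

    full_batch = torch.zeros(64, 128).long()    # what one batch from the dataloader looks like
    chunks = torch.chunk(full_batch, 2, dim=0)  # roughly what DataParallel's scatter does
    print([c.size(0) for c in chunks])          # [32, 32]
    # ...but self.hidden was built with shape (1, 64, hidden_dim)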

Here’s the training script:

    model = CryptoLSTM(args.embedding_dim, args.hidden_dim,
                       args.batch_size, len(alphabet))
    model = torch.nn.DataParallel(model).cuda()
    loss_function = nn.NLLLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    for epoch in range(args.num_epochs):
        for batch in dataloader:
            model.zero_grad()
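            # reset the hidden state to zeros; note it is sized for the full args.batch_size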
            model.module.init_hidden()

            inputs, targets = batch
            predictions = model(inputs)
            # predictions.size() == 64x128x27
            # NLLLoss expects classes to be the second dim
            predictions = predictions.transpose(1, 2)
            # predictions.size() == 64x27x128
            loss = loss_function(predictions, targets)
            loss.backward()
            optimizer.step()

And here’s the traceback:

Traceback (most recent call last):
  File "train.py", line 244, in <module>
    main()
  File "train.py", line 186, in main
    predictions = model(inputs)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/parallel/data_parallel.py", line 73, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/parallel/data_parallel.py", line 83, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/parallel/parallel_apply.py", line 67, in parallel_apply
    raise output
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/parallel/parallel_apply.py", line 42, in _worker
    output = module(*input, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "train.py", line 132, in forward
    lstm_out, self.hidden = self.lstm(embeds, self.hidden)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/rnn.py", line 190, in forward
    self.check_forward_args(input, hx, batch_sizes)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/rnn.py", line 158, in check_forward_args
    'Expected hidden[0] size {}, got {}')
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/rnn.py", line 154, in check_hidden_size
    raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))
RuntimeError: Expected hidden[0] size (1, 32, 2000), got (1, 64, 2000)
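
The only workaround I can think of so far is to drop self.hidden entirely and build the hidden state inside forward from the input’s actual batch size, roughly like this (untested sketch, not what my current code does):

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        # size the hidden state from this replica's slice of the batch,
        # not from self.batch_size
        batch = embeds.size(0)
        zeros = embeds.data.new(1, batch, self.hidden_dim).zero_()
        hidden = (torch.autograd.Variable(zeros),
                  torch.autograd.Variable(zeros.clone()))
        lstm_out, hidden = self.lstm(embeds, hidden)
        tag_space = self.hidden2text(lstm_out)
        return F.log_softmax(tag_space, dim=2)

But I’m not sure whether that’s the right way to combine a stateful LSTM with DataParallel, or whether I’m missing something simpler.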

I’m using torch 0.3.1 on Python 3.5. Any help would be greatly appreciated.