DataParallel LSTM/GRU wrong hidden batch size (8 GPUs)

I'm confused about how to use DataParallel properly over multiple GPUs, because it seems to be distributing along the wrong dimension (the code works fine on a single GPU).

The model, using dim=0 in DataParallel, batch_size=32, and 8 GPUs, is:

import torch
import torch.nn as nn
from torch.autograd import Variable

class StepRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.encoder = nn.Embedding(input_size, hidden_size)
        self.rnn = nn.GRU(input_size=hidden_size,
                          hidden_size=hidden_size,
                          num_layers=num_layers)
        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden):
        batch_size = input.size(0)
        encoded = self.encoder(input)
        output, hidden = self.rnn(encoded.view(1, batch_size, -1), hidden)
        output = self.decoder(output.view(batch_size, -1))
        return output, hidden

    def init_hidden(self, batch_size):
        return Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size))


decoder = StepRNN(
    input_size=100,
    hidden_size=64,
    output_size=100,
    num_layers=1)

decoder_dist = nn.DataParallel(decoder, device_ids=[0,1,2,3,4,5,6,7], dim=0)
decoder_dist.cuda()

batch_size = 32
hidden = decoder.init_hidden(batch_size).cuda()
input_ = Variable(torch.LongTensor(batch_size, 10)).cuda()
target = Variable(torch.LongTensor(batch_size, 10)).cuda()

for c in range(10):
    decoder_dist(input_[:, c].contiguous(), hidden)  # RuntimeError: Expected hidden size (1, 4, 64), got (1, 32, 64)

The result is RuntimeError: Expected hidden size (1, 4, 64), got (1, 32, 64). It makes sense that it's expecting a hidden batch size of 32/8 = 4, but the full batch seems to be passed through. What am I missing? Full traceback here.

With dim=1 I get RuntimeError: invalid argument 2: out of range. Full trace here.

Interestingly, if I open an IPython session and run the code once, I get the runtime error above. But if I run it again unchanged, I get a different error: RuntimeError: cuda runtime error (59) : device-side assert triggered at /opt/conda/conda-bld/pytorch_1502009910772/work/torch/lib/THC/generic/THCTensorCopy.c:18. This seems pretty consistent, but I'm not sure why the error would change with the exact same code.

I found another question where the issue was related to using batch_first=True, so splitting along dim=0 by default didn't work. But I'm using the default batch_first=False.


If it's an nn.GRU, I think you have to use the flag batch_first=True to make sure the inputs are interpreted as having the mini-batch in dimension 0, like the sketch below.

http://pytorch.org/docs/master/nn.html?highlight=batch_first#torch.nn.GRU
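For concreteness, here is a minimal sketch of that suggestion (hypothetical sizes, not code from the docs page above). With batch_first=True the GRU input is (batch, seq, feature), so DataParallel's default split along dim 0 lines up with the batch dimension of the input; note the hidden state is still (num_layers, batch, hidden_size) regardless of batch_first:

import torch
import torch.nn as nn
from torch.autograd import Variable

gru = nn.GRU(input_size=64, hidden_size=64, num_layers=1, batch_first=True)
x = Variable(torch.zeros(32, 10, 64))   # (batch, seq, feature)
h0 = Variable(torch.zeros(1, 32, 64))   # (num_layers, batch, hidden_size) -- batch is still dim 1
out, hn = gru(x, h0)
print(out.size())                       # (32, 10, 64) because batch_first=True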


I changed my code to use an LSTM like so:

import torch
import torch.nn as nn
from torch.autograd import Variable

class StepRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.encoder = nn.Embedding(input_size, hidden_size)
        self.rnn = nn.LSTM(input_size=hidden_size,
                           hidden_size=hidden_size,
                           num_layers=num_layers)
        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden):
        batch_size = input.size(0)
        encoded = self.encoder(input)
        output, hidden = self.rnn(encoded.view(1, batch_size, -1), hidden)
        output = self.decoder(output.view(batch_size, -1))
        return output, hidden

    def init_hidden(self, batch_size):
        return (Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size)).cuda(),
                Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size)).cuda())


decoder = StepRNN(
    input_size=100,
    hidden_size=8,
    output_size=100,
    num_layers=1)

decoder_dist = nn.DataParallel(decoder, device_ids=[0,1,2,3,4,5,6,7], dim=0)
decoder_dist.cuda()

batch_size = 16
hidden = decoder.init_hidden(batch_size)
input_ = Variable(torch.LongTensor(batch_size, 10)).cuda()
target = Variable(torch.LongTensor(batch_size, 10)).cuda()

for c in range(10):
    decoder_dist(input_[:, c].contiguous(), hidden)

The result is again RuntimeError: Expected hidden size (1, 2, 8), got (1, 16, 8). Full trace. It doesn't affect only GRU, so I modified the title of this post to help future searches.

What is the right way to parallelize that is consistent with the PyTorch defaults? It seems like DataParallel expects data in a non-standard layout, or am I missing something?

As a beginner, it's confusing to have to rethink everything I've learned in terms of batch_first=True. How can I use DataParallel with the defaults, or how would I have to modify the code above to use batch_first=True?

Thanks for any input.


@jekbradbury any pointers for using RNNs + DataParallel?


I've traced the source of the error to cudnn/rnn.py, in the forward pass that starts at line 190. The error comes from this part of the code:

190    def forward(fn, input, hx, weight, output, hy):
191        with torch.cuda.device_of(input):

        (...)

264            if tuple(hx.size()) != hidden_size:
265                raise RuntimeError('Expected hidden size {}, got {}'.format(
266                    hidden_size, tuple(hx.size())))

where it looks like hidden_size is the expected (already split) size and hx is the actual hidden state being passed to each GPU (there's a similar snippet at line 370, but I think this is the relevant one because it's in the forward pass).

The offending object is hx, which is passed to forward as an argument, so it looks like forward expects hx to have already been split by some previous step?
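If I understand DataParallel's scatter correctly, it splits every tensor argument with chunk along the chosen dim. A rough CPU illustration with the sizes from the snippet above: the input column (batch of 32) splits into 8 pieces of 4, but the hidden state of shape (1, 32, 64) has nothing to split along dim 0, so replica 0 receives it whole:

import torch

input_col = torch.zeros(32)        # one column of the (32, 10) input batch
hidden = torch.zeros(1, 32, 64)    # (num_layers, batch, hidden_size)

print([c.size() for c in input_col.chunk(8, dim=0)])  # eight chunks of size (4,)
print([c.size() for c in hidden.chunk(8, dim=0)])     # a single chunk of size (1, 32, 64)
# -> replica 0 sees a batch of 4 but a batch-32 hidden state:
#    "Expected hidden size (1, 4, 64), got (1, 32, 64)"
# (Chunking the 1-D input column along dim=1 would raise a "dimension out of range"
#  error, which may be what the dim=1 attempt above ran into.)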

I ran into the same problem when trying to combine RNN modules with DataParallel.

If you wrap the RNN in a module whose forward function only requires an input parameter, it works fine. It doesn't seem to work when you need both the input and the hidden parameter. If you can contain your hidden-state logic within your module, it's an effective workaround.

Thanks, I'm not sure I follow: what do you mean by containing the hidden parameter inside the module?

Glad to hear I’m not the only one having this annoying issue :wink:


Something like this (not working code):

class LSTM(nn.Module):
    def __init__(self, initial_state):
        super(LSTM, self).__init__()

        self.lstm = nn.LSTM(
            ...
            batch_first=True)
        self.hn = initial_state
            
    def forward(self, input):
        output, hn = self.lstm(input, self.hn)
        self.hn = hn
        return output
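For reference, one self-contained reading of this workaround (hypothetical module and sizes, not the snippet above): build the initial hidden state inside forward from the input itself, so each DataParallel replica sizes it for its own slice of the batch. Carrying self.hn across calls is trickier under DataParallel, since replicas are recreated on every forward, so this sketch simply re-initializes on each call:

import torch
import torch.nn as nn
from torch.autograd import Variable

class WrappedLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size,
                            num_layers=num_layers, batch_first=True)

    def forward(self, input):                  # only one argument
        batch_size = input.size(0)             # this replica's slice of the batch
        weight = next(self.parameters()).data  # borrow dtype/device from the weights
        h0 = Variable(weight.new(self.num_layers, batch_size, self.hidden_size).zero_())
        c0 = Variable(weight.new(self.num_layers, batch_size, self.hidden_size).zero_())
        output, _ = self.lstm(input, (h0, c0))
        return output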

Thanks for the snippet, I see what you meant. It seems a bit similar to other responses I've seen here, where people end up having to use the non-default batch_first=True. I would definitely try to help debug, but this issue is a bit over my head.

@jekbradbury Will batch_first=True be the standard way of using DataParallel with RNNs going forward? I haven't been able to use it with this setting, but if it will be the standard way I might as well learn how to do it :slight_smile: It would also be nice to have it documented somewhere. Thanks.

I think both batch-first and batch-second modes are compatible with DataParallel (it assumes batch-first by default, since that’s true of all non-RNN-related tensors, but it has a keyword argument to split over a different dimension). Both modes are definitely compatible with the rest of the RNN infrastructure, including pack_padded_sequence.

DataParallel is not working for me over multiple GPUs with batch_first=False, and IIRC there are other questions on the forum with similar issues. The two snippets I posted above (GRU and LSTM) will not work on multiple GPUs even when splitting on a different dimension with batch_first=False (I made the snippets self-contained to make them easy to verify). From other questions here it seems that batch_first=True works fine, but I don't think it works with False, unless my code is wrong, which is entirely possible. If you have a minute I'd appreciate a sanity check of the code, as I'm still learning PyTorch and can't say for sure.


What’s the conclusion? I have the same issue. Should we use batch_first=True? I want to use DataParallel with batch_first=False.

How about this?

Just keep batch_first=False, but feed the batch data as B x S. Then you just need to transpose once inside your RNN module, and that's it.

def forward(self, input, seq_lengths):
    # Note: we run this all at once (over the whole input sequence)
    # input shape: B x S (input size)
    # transpose to make S(sequence) x B (batch)
    input = input.t()
    batch_size = input.size(1)

Then, everything is OK. For the entire code, please check out at https://github.com/hunkim/PyTorchZeroToAll/blob/master/12_4_name_classify.py.
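Spelled out as a minimal hedged sketch (hypothetical module and sizes, not the code from the linked repo): DataParallel scatters the batch-first B x S input along dim 0, and the module transposes to S x B before calling the default batch_first=False GRU, building the hidden state from its own slice of the batch:

import torch
import torch.nn as nn
from torch.autograd import Variable

class TransposingRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers)  # batch_first=False (default)

    def forward(self, input):
        # input arrives as B x S, which is what DataParallel splits along dim 0
        input = input.t()                  # now S x B
        batch_size = input.size(1)
        weight = next(self.parameters()).data
        hidden = Variable(weight.new(self.num_layers, batch_size, self.hidden_size).zero_())
        embedded = self.embedding(input)   # S x B x hidden_size
        output, hidden = self.gru(embedded, hidden)
        return output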


I'm using batch_first=True, and my forward function requires just one parameter. I'm still facing errors. Here's my model.

class CryptoLSTM(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, batch_size, vocab_size):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.hidden2text = nn.Linear(hidden_dim, vocab_size)
        self.hidden = self.init_hidden()

    def init_hidden(self):
        self.hidden = (torch.autograd.Variable(torch.zeros(1, self.batch_size,
            self.hidden_dim).cuda()), torch.autograd.Variable(torch.zeros(
            1, self.batch_size, self.hidden_dim).cuda()))

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        lstm_out, self.hidden = self.lstm(embeds, self.hidden)
        tag_space = self.hidden2text(lstm_out)
        scores = F.log_softmax(tag_space, dim=2)
        return scores

Here’s the training script

    model = CryptoLSTM(args.embedding_dim, args.hidden_dim,
                       args.batch_size, len(alphabet))
    model = torch.nn.DataParallel(model).cuda()
    loss_function = nn.NLLLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    for epoch in range(args.num_epochs):
        for batch in dataloader:
            model.zero_grad()
            model.module.init_hidden()

            inputs, targets = batch
            predictions = model(inputs)
            # predictions.size() == 64x128x27
            # NLLLoss expects classes to be the second dim
            predictions = predictions.transpose(1, 2)
            # predictions.size() == 64x27x128
            loss = loss_function(predictions, targets)
            loss.backward()
            optimizer.step()

And here’s the Traceback.

Traceback (most recent call last):
  File "train.py", line 244, in <module>
    main()
  File "train.py", line 186, in main
    predictions = model(inputs)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/parallel/data_parallel.py", line 73, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/parallel/data_parallel.py", line 83, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/parallel/parallel_apply.py", line 67, in parallel_apply
    raise output
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/parallel/parallel_apply.py", line 42, in _worker
    output = module(*input, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "train.py", line 132, in forward
    lstm_out, self.hidden = self.lstm(embeds, self.hidden)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/rnn.py", line 190, in forward
    self.check_forward_args(input, hx, batch_sizes)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/rnn.py", line 158, in check_forward_args
    'Expected hidden[0] size {}, got {}')
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/rnn.py", line 154, in check_hidden_size
    raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))
RuntimeError: Expected hidden[0] size (1, 32, 2000), got (1, 64, 2000)

I’m using torch 0.3.1 on Python 3.5. Any help would be greatly appreciated.

I am using DataParallel with a GRU and batch_first=True, and I don't think the GRU's hidden state is output the way the documentation says it's supposed to be.

Facing the same problem. Any conclusions now?

Same issue here, trying to use DataParallel with an LSTM model:
RuntimeError: Expected hidden[0] size (1, 2500, 50), got (1, 10000, 50)

I can see from the shape mismatch what the general problem is. The hidden state is being created for the entire model input (10,000 in my case), while DataParallel divides that input by the GPU count (4 in my case) to spread the load. Maybe we could also wrap the hidden input tensor with DataParallel so it's distributed correctly?

I might have found a workaround for this issue, or maybe it's actually the correct way to implement this. According to the torch.nn.LSTM docs:
“If (h_0, c_0) is not provided, both h_0 and c_0 default to zero.”

So the workaround is basically to let nn.LSTM initialize the hidden state itself rather than having separate init_hidden logic. Some might say this is the correct way to initialize the hidden state. Thoughts?
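As a hedged sketch of that workaround (hypothetical sizes): don't pass (h_0, c_0) at all, and each DataParallel replica's LSTM zero-initializes a hidden state sized for its own slice of the batch:

import torch
import torch.nn as nn
from torch.autograd import Variable

class NoInitHiddenLSTM(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

    def forward(self, input):
        # No (h_0, c_0) argument: per the docs quoted above, both default to zeros
        # sized for whatever batch this replica actually receives.
        output, (hn, cn) = self.lstm(input)
        return output

model = nn.DataParallel(NoInitHiddenLSTM(32, 50)).cuda()
x = Variable(torch.zeros(16, 10, 32)).cuda()   # (batch, seq, feature)
out = model(x)                                 # each GPU builds its own hidden state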

I think the reason DataParallel didn't work when you need both the input and the hidden parameter is that the h_0 shape is (num_layers * num_directions, batch, hidden_size) whether batch_first is True or False.

See the description:

    """
    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two LSTMs together to form a `stacked LSTM`,
            with the second LSTM taking in outputs of the first LSTM and
            computing the final results. Default: 1
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            LSTM layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False``

    Inputs: input, (h_0, c_0)
        - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
          of the input sequence.
          The input can also be a packed variable length sequence.
          See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
          :func:`torch.nn.utils.rnn.pack_sequence` for details.
        - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the initial hidden state for each element in the batch.
          If the RNN is bidirectional, num_directions should be 2, else it should be 1.
        - **c_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the initial cell state for each element in the batch.

          If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.
    """

batch_first only decides the input and output sizes, not the hidden state's, which is what causes the problem. Maybe you can try using batch_first=False and changing your input sizes to match.

Again, I tried using batch_first=True in the LSTM. It turns out it works if you provide the hidden state with shape (num_layers * num_directions, batch, hidden_size), or don't specify the hidden state at all like @robd2 said; either way works.
