LSTM hidden[0] doesn't seem to pick up batch_first argument

I am trying to train an LSTM network. My data has shape (batch size, sequence length, features), so I set batch_first = True when defining my LSTM. I use a batch size of 30, a hidden size of 200, and a two-layer bidirectional network. When I try to train it, however, I get "Expected hidden[0] size (4, 30, 200), got (30, 4, 200)". The sizes themselves look right, but the model seems to expect the hidden state shaped as if batch_first were False, even though batch_first is set to True. What am I missing here?

Thanks for the help.

Please see my code below.

import time

import torch
import torch.nn as nn
import torch.optim as optim


class lstm_net(nn.Module):

    def __init__(self, input_size, hidden_size, output_size, bias = True):
        super().__init__()

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers = 2, bidirectional = True, bias = bias, batch_first = True)
        self.linear = nn.Linear(hidden_size, output_size, bias = bias)

    def forward(self, input_seq, h_init, c_init):

        output_seq, (h_last, c_last) = self.lstm(input_seq, (h_init, c_init))
        scores = self.linear(output_seq)
        return scores

# Training loop 
net = lstm_net(2, 200, 2, bias = True)
bs = 30
lr = 1
criterion = nn.CrossEntropyLoss()

start = time.time()
for epoch in range(1, 11):

    # Learning schedule
    # TBD

    # Setup optimizer
    optimizer = optim.SGD(net.parameters(), lr = lr)

    # Initialize stats to zeros to track network's progress
    running_loss = 0
    running_error = 0
    num_batches = 0

    # Shuffle indices to randomize training
    shuffled_indices = torch.randperm(19481)

    for count in range(0, 19481 - bs, bs):

        # Initialize h and c to be zero
        h = torch.zeros(bs, 4, 200)
        c = torch.zeros(bs, 4, 200)

        # Detach prior gradient
        h = h.detach()
        c = c.detach()

        # Track changes
        h = h.requires_grad_()
        c = c.requires_grad_()

        # Set gradient to 0
        optimizer.zero_grad()

        # Make minibatch
        indices = shuffled_indices[count : count + bs]
        minibatch_data = train_data[indices]
        minibatch_label = train_label[indices]
        print(minibatch_data.size())

        # Track changes
        minibatch_data.requires_grad_()

        # Send minibatch through network
        scores, (h, c) = net(minibatch_data, h, c)

        # Compute loss of minibatch
        loss = criterion(scores, minibatch_label)

        # Backward pass
        loss.backward()

        # Do one step of stochastic gradient descent
        normalize_gradient(net)
        optimizer.step()

        # Update summary statistics
        with torch.no_grad():
            running_loss += loss.item()
            error = get_error(scores, minibatch_label)
            running_error += error
            num_batches += 1

    # At the end of each epoch, print summary statistics
    elapsed = time.time() - start
    avg_loss = running_loss / num_batches
    avg_error = running_error / num_batches
    print('| EPOCH {} |'.format(epoch))
    print('='*len('| EPOCH {} |'.format(epoch)))
    print('')
    print('Error: ', '{}%'.format(avg_error * 100), '\t Loss: ', avg_loss, '\t Time: ', '{} minutes'.format(elapsed / 60))

And the error I’m getting:

torch.Size([30, 90, 2])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-25-ba433b2f7ded> in <module>
     48 
     49         # Send minibatch through network
---> 50         scores, (h, c) = net(minibatch_data, h, c)
     51 
     52         # Compute loss of minibatch

~\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    539             result = self._slow_forward(*input, **kwargs)
    540         else:
--> 541             result = self.forward(*input, **kwargs)
    542         for hook in self._forward_hooks.values():
    543             hook_result = hook(self, input, result)

<ipython-input-21-a7a31b056ff4> in forward(self, input_seq, h_init, c_init)
    194     def forward(self, input_seq, h_init, c_init):
    195 
--> 196         output_seq, (h_last, c_last) = self.lstm(input_seq, (h_init, c_init))
    197         scores = self.linear(output_seq)
    198         return scores

~\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    539             result = self._slow_forward(*input, **kwargs)
    540         else:
--> 541             result = self.forward(*input, **kwargs)
    542         for hook in self._forward_hooks.values():
    543             hook_result = hook(self, input, result)

~\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\rnn.py in forward(self, input, hx)
    562             return self.forward_packed(input, hx)
    563         else:
--> 564             return self.forward_tensor(input, hx)
    565 
    566 

~\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\rnn.py in forward_tensor(self, input, hx)
    541         unsorted_indices = None
    542 
--> 543         output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
    544 
    545         return output, self.permute_hidden(hidden, unsorted_indices)

~\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\rnn.py in forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices)
    521             hx = self.permute_hidden(hx, sorted_indices)
    522 
--> 523         self.check_forward_args(input, hx, batch_sizes)
    524         if batch_sizes is None:
    525             result = _VF.lstm(input, hx, self._get_flat_weights(), self.bias, self.num_layers,

~\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\rnn.py in check_forward_args(self, input, hidden, batch_sizes)
    498 
    499         self.check_hidden_size(hidden[0], expected_hidden_size,
--> 500                                'Expected hidden[0] size {}, got {}')
    501         self.check_hidden_size(hidden[1], expected_hidden_size,
    502                                'Expected hidden[1] size {}, got {}')

~\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\rnn.py in check_hidden_size(self, hx, expected_hidden_size, msg)
    164         # type: (Tensor, Tuple[int, int, int], str) -> None
    165         if hx.size() != expected_hidden_size:
--> 166             raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))
    167 
    168     def check_forward_args(self, input, hidden, batch_sizes):

RuntimeError: Expected hidden[0] size (4, 30, 200), got (30, 4, 200)

As far as I know, batch_first only affects the input and output tensors, not the hidden state.
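
For reference, here is a minimal sketch (using the same sizes as in the code above: input size 2, hidden size 200, 2 layers, bidirectional, a batch of 30 sequences of length 90) showing which tensors batch_first actually affects:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=2, hidden_size=200, num_layers=2,
               bidirectional=True, batch_first=True)

# Input is (batch, seq_len, features) because batch_first=True
x = torch.zeros(30, 90, 2)

# The hidden and cell states are (num_layers * num_directions, batch, hidden_size)
# regardless of batch_first, i.e. (2 * 2, 30, 200) = (4, 30, 200)
h0 = torch.zeros(4, 30, 200)
c0 = torch.zeros(4, 30, 200)

output, (hn, cn) = lstm(x, (h0, c0))
print(output.shape)  # torch.Size([30, 90, 400]) -- batch first, 2 * hidden_size features
print(hn.shape)      # torch.Size([4, 30, 200])  -- still (layers * directions, batch, hidden)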

I apologize if this sounds naive, but what exactly is meant by hidden[0], and how can I change it? As far as I can tell, the LSTM is set up properly and the inputs are given batch first, as would be expected with batch_first = True, so how can I change hidden[0] to the correct dimensions? It seems like this shouldn't be necessary: with batch_first = True and an input that does have the batch dimension first, it seems like it should just work. Still, I am not an expert in PyTorch, so there may be an additional step with the inputs or with setting up the LSTM class that I am missing.

I appreciate your help. Thanks.

I ended up figuring this out. After a closer reading of the docs, I noticed that they do specify that with batch_first = True, only the input and output tensors have the batch dimension first. The initial states (h_init and c_init) are still expected with the batch dimension second, i.e. with shape (num_layers * num_directions, batch, hidden_size).
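
For anyone who hits this later, the fix in my training loop was just to swap the first two dimensions when initializing the states. A minimal sketch, keeping my sizes (2 layers * 2 directions = 4, batch size bs, hidden size 200):

# Shape is (num_layers * num_directions, batch, hidden_size), not (batch, 4, hidden_size)
h = torch.zeros(4, bs, 200)
c = torch.zeros(4, bs, 200)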

In my opinion this is a bit confusing; it seems like setting batch_first = True should make everything batch first. But at least I figured it out.

Thanks for your help.


Yup, these are the little quirks one simply has to pick up along the road :). Happy coding!

I just came across the same problem. For the sake of consistency, I agree. I hope PyTorch will update this in a later version.