Expected hidden size different to actual hidden size

I am trying to train a BiLSTM and am running into a RuntimeError towards the end of the first training epoch.

Model

import math

import torch
from torch import nn, Tensor

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # assumed device setup

class BiLSTM(nn.Module):
  def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout_rate, tie_weights):
    super().__init__()

    self.num_layers = num_layers
    self.hidden_dim = hidden_dim
    self.embedding_dim = embedding_dim

    self.embedding = nn.Embedding(vocab_size, embedding_dim)
    self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers, 
                        dropout=dropout_rate, batch_first=True, bidirectional=True)
    self.dropout = nn.Dropout(dropout_rate)
    self.linear = nn.Linear(hidden_dim*2, vocab_size) # hidden_dim*2 because the LSTM is bidirectional

    if tie_weights:
      # Tied matrices must have the same shape: linear.weight is
      # (vocab_size, hidden_dim*2) because the LSTM is bidirectional
      assert embedding_dim == hidden_dim * 2, 'Cannot tie weights, check dimensions'
      self.linear.weight = self.embedding.weight

    self.init_weights()

  def forward(self, x, hidden):
    output = self.embedding(x)
    output, hidden = self.lstm(output, hidden)
    output = self.dropout(output)
    output = self.linear(output)
    return output, hidden

  def init_weights(self):
    init_range_emb = 0.1
    init_range_other = 1/math.sqrt(self.hidden_dim)
    self.embedding.weight.data.uniform_(-init_range_emb, init_range_emb)
    self.linear.weight.data.uniform_(-init_range_other, init_range_other)
    self.linear.bias.data.zero_()

  def init_hidden(self, batch_size):
    # one state tensor per layer per direction: num_layers*2 for a bidirectional LSTM
    hidden = torch.zeros(self.num_layers*2, batch_size, self.hidden_dim).to(device)
    cell = torch.zeros(self.num_layers*2, batch_size, self.hidden_dim).to(device)
    return hidden, cell

  def detach_hidden(self, hidden):
    # detach so gradients do not propagate across bptt chunks
    hidden, cell = hidden
    hidden = hidden.detach()
    cell = cell.detach()
    return hidden, cell

Model Call

vocab_size = len(vocab)
embedding_dim = 100
hidden_dim = 100
num_layers = 2
dropout_rate = 0.4
tie_weights = False
model = BiLSTM(vocab_size, embedding_dim, hidden_dim, num_layers, dropout_rate, tie_weights)
model.to(device)

BiLSTM(
  (embedding): Embedding(9922, 100)
  (lstm): LSTM(100, 100, num_layers=2, batch_first=True, dropout=0.4, bidirectional=True)
  (dropout): Dropout(p=0.4, inplace=False)
  (linear): Linear(in_features=200, out_features=9922, bias=True)
)
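For reference, a bidirectional LSTM keeps one state tensor per layer per direction, so the state shape is (num_layers * 2, batch, hidden_dim); with batch_size = 20 this gives the [4, 20, 100] shape that appears in the error below:

h, c = model.init_hidden(20)
print(h.shape, c.shape)  # torch.Size([4, 20, 100]) torch.Size([4, 20, 100])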

Training

import copy
import time

criterion = nn.CrossEntropyLoss()
lr = 20.0  # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

def train(model: nn.Module) -> None:
    model.train()  # turn on train mode
    total_loss = 0.
    log_interval = 200
    start_time = time.time()
    
    hidden = model.init_hidden(batch_size)
    
    num_batches = len(train_data) // bptt
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        hidden = model.detach_hidden(hidden)
        data, targets = get_batch(train_data, i)
        seq_len = data.size(0)
        output, hidden = model(data, hidden)
        loss = criterion(output.view(-1, vocab_size), targets)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            ppl = math.exp(cur_loss)
            print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
                  f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '
                  f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')
            total_loss = 0
            start_time = time.time()

Eval function

def evaluate(model: nn.Module, eval_data: Tensor) -> float:
    model.eval()  # turn on evaluation mode
    total_loss = 0.
    hidden = model.init_hidden(eval_batch_size)
    with torch.no_grad():
        for i in range(0, eval_data.size(0) - 1, bptt):
            hidden = model.detach_hidden(hidden)
            data, targets = get_batch(eval_data, i)
            seq_len = data.size(0)
            output, hidden = model(data, hidden)
            output_flat = output.view(-1, vocab_size)
            total_loss += seq_len * criterion(output_flat, targets).item()
    return total_loss / (len(eval_data) - 1)

Training Loop

best_val_loss = float('inf')
epochs = 50
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(model)
    val_loss = evaluate(model, val_data)
    val_ppl = math.exp(val_loss)
    elapsed = time.time() - epoch_start_time
    print('-' * 89)
    print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
          f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)

    scheduler.step()

Error

RuntimeError                              Traceback (most recent call last)
<ipython-input-38-453c3f2a9cad> in <cell line: 5>()
      5 for epoch in range(1, epochs + 1):
      6     epoch_start_time = time.time()
----> 7     train(model)
      8     val_loss = evaluate(model, val_data)
      9     val_ppl = math.exp(val_loss)

6 frames
<ipython-input-37-16d7ac1074e2> in train(model)
     20         data, targets = get_batch(train_data, i)
     21         seq_len = data.size(0)
---> 22         output, hidden = model(data, hidden)
     23         loss = criterion(output.view(-1, vocab_size), targets)
     24 

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-17-b0382b88f83d> in forward(self, x, hidden)
     24     #cell_0 = torch.zeros(2*num_layers, batch_size, self.hidden_dim).to(device)
     25     output  = self.embedding(x)
---> 26     output, hidden = self.lstm(output, hidden)
     27     output = self.dropout(output)
     28     output = self.linear(output)

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/rnn.py in forward(self, input, hx)
    808             hx = self.permute_hidden(hx, sorted_indices)
    809 
--> 810         self.check_forward_args(input, hx, batch_sizes)
    811         if batch_sizes is None:
    812             result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/rnn.py in check_forward_args(self, input, hidden, batch_sizes)
    729                            ):
    730         self.check_input(input, batch_sizes)
--> 731         self.check_hidden_size(hidden[0], self.get_expected_hidden_size(input, batch_sizes),
    732                                'Expected hidden[0] size {}, got {}')
    733         self.check_hidden_size(hidden[1], self.get_expected_cell_size(input, batch_sizes),

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/rnn.py in check_hidden_size(self, hx, expected_hidden_size, msg)
    237                           msg: str = 'Expected hidden size {}, got {}') -> None:
    238         if hx.size() != expected_hidden_size:
--> 239             raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
    240 
    241     def _weights_have_changed(self):

RuntimeError: Expected hidden[0] size (4, 19, 100), got [4, 20, 100]

What I don’t understand is that batch_size is set to 20, so the hidden state passed to the model has shape [4, 20, 100]; it is initialised with

hidden = torch.zeros(self.num_layers*2, batch_size, self.hidden_dim).to(device)

so the model should just keep expecting hidden tensors of shape [4, 20, 100]. I don’t know why it suddenly expects a different size.
Any help appreciated.

I guess the last batch might be smaller, since dividing the number of samples in the dataset by the batch size can leave a remainder.
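You can verify this by printing the shape of each chunk your get_batch returns; a quick probe, assuming the train_data, bptt, and get_batch used in your train function:

for i in range(0, train_data.size(0) - 1, bptt):
    data, _ = get_batch(train_data, i)
    if data.size(0) != bptt:
        # the last chunk is shorter whenever (train_data.size(0) - 1) % bptt != 0
        print(f'chunk at i={i} has shape {tuple(data.shape)}')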
I don’t know how you are initializing the DataLoader, or whether you are using one at all, but you could either specify drop_last=True to drop this smaller batch at the end, or initialize the hidden state from the actual batch size of the current input instead of the fixed batch_size variable.
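Here is a sketch of the second option, adapted to your index-based batching (where drop_last does not directly apply). Since the LSTM was built with batch_first=True, it reads the batch size from data.size(0), so you can rebuild the state whenever the incoming chunk no longer matches the state you are carrying:

hidden = model.init_hidden(batch_size)
for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
    data, targets = get_batch(train_data, i)
    # with batch_first=True the LSTM expects hidden of shape
    # (num_layers*2, data.size(0), hidden_dim), so re-init when the chunk shrinks
    if data.size(0) != hidden[0].size(1):
        hidden = model.init_hidden(data.size(0))
    hidden = model.detach_hidden(hidden)
    output, hidden = model(data, hidden)
    loss = criterion(output.view(-1, vocab_size), targets)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()

Re-initialising from data.size(0) zeroes the carried state for that shorter chunk, which is usually acceptable since it only affects the tail of each epoch.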