I am trying to understand how the “N = batch size” dimension works for an LSTM (doc), and I find it a bit confusing.
Apparently, this works:
import torch
from torch.nn import Embedding, LSTM
num_chars = 8
batch_size = 2
embedding_dim = 3
hidden_size = 5
num_layers = 1
embed = Embedding(num_chars, embedding_dim)
lstm = LSTM(input_size=embedding_dim, hidden_size=hidden_size)  # batch_first defaults to False
# Initial hidden and cell states; note the second dim is 8 // batch_size == 4, not batch_size
hiddens = torch.zeros(num_layers, 8 // batch_size, hidden_size)
cells = torch.zeros(num_layers, 8 // batch_size, hidden_size)
sequence = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
batched = sequence.view((batch_size, -1))  # shape (2, 4)
embedded = embed(batched)                  # shape (2, 4, 3)
last_hidden, (hiddens, cells) = lstm(embedded, (hiddens, cells))
The output (last_hidden) has shape (2, 4, 5), and the returned hidden and cell states both have shape (1, 4, 5), which sounds right.
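Printing the shapes confirms this (just re-checking the run above):
print(embedded.shape)     # torch.Size([2, 4, 3])
print(last_hidden.shape)  # torch.Size([2, 4, 5])
print(hiddens.shape)      # torch.Size([1, 4, 5])
print(cells.shape)        # torch.Size([1, 4, 5])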
But I had to use 8 // batch_size when setting up the initial hidden and cell states, whereas the doc says h_0 should have shape (1, batch_size, H_out) for a single-layer, unidirectional LSTM. My batch_size is 2, not 4.
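For comparison, here is a sketch of what I expected to work based on the doc, continuing from the setup above; as far as I can tell, it fails with a size mismatch instead:
# Doc-style initial states: (num_layers, batch_size, hidden_size) == (1, 2, 5)
h0 = torch.zeros(num_layers, batch_size, hidden_size)
c0 = torch.zeros(num_layers, batch_size, hidden_size)
# Fails with a RuntimeError (something like: Expected hidden[0] size (1, 4, 5), got [1, 2, 5])
out, (hn, cn) = lstm(embedded, (h0, c0))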
What am I missing?