I am trying to understand how the “N = batch size” dimension works for an LSTM (doc), and I find it a bit confusing.
Apparently, this works:
import torch
from torch.nn import Embedding, LSTM
num_chars = 8
batch_size = 2
embedding_dim = 3
hidden_size = 5
num_layers = 1
embed = Embedding(num_chars, embedding_dim)
lstm = LSTM(input_size=embedding_dim, hidden_size=hidden_size)  # batch_first defaults to False
# Initial hidden and cell states; note the second dim is 8 // batch_size == 4, not batch_size
hiddens = torch.zeros(num_layers, 8 // batch_size, hidden_size)
cells = torch.zeros(num_layers, 8 // batch_size, hidden_size)
sequence = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
batched = sequence.view((batch_size, -1))  # shape (2, 4)
embedded = embed(batched)                  # shape (2, 4, 3)
last_hidden, (hiddens, cells) = lstm(embedded, (hiddens, cells))
The output (last_hidden) has shape (2, 4, 5), and the returned hidden and cell states both have shape (1, 4, 5), which sounds right.
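Printing the shapes confirms this (just re-checking the run above):
print(embedded.shape)     # torch.Size([2, 4, 3])
print(last_hidden.shape)  # torch.Size([2, 4, 5])
print(hiddens.shape)      # torch.Size([1, 4, 5])
print(cells.shape)        # torch.Size([1, 4, 5])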
But I had to use 8 // batch_size when setting up the initial hidden and cell states, whereas the doc says h_0 should have shape (1, batch_size, H_out) for a single-layer, unidirectional LSTM. My batch_size is 2, not 4.
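For comparison, here is a sketch of what I expected to work based on the doc, continuing from the setup above; as far as I can tell, it fails with a size mismatch instead:
# Doc-style initial states: (num_layers, batch_size, hidden_size) == (1, 2, 5)
h0 = torch.zeros(num_layers, batch_size, hidden_size)
c0 = torch.zeros(num_layers, batch_size, hidden_size)
# Fails with a RuntimeError (something like: Expected hidden[0] size (1, 4, 5), got [1, 2, 5])
out, (hn, cn) = lstm(embedded, (h0, c0))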
What am I missing?