LSTM input shape with "batches"

Hello!

I am trying to understand how the “N = batch size” option works for an LSTM (doc), and I find it a bit confusing.

Apparently, this works:

import torch
from torch.nn import Embedding, LSTM

num_chars = 8
batch_size = 2
embedding_dim = 3
hidden_size = 5
num_layers = 1

embed = Embedding(num_chars, embedding_dim)
lstm = LSTM(input_size=embedding_dim, hidden_size=hidden_size)

hiddens = torch.zeros(num_layers, 8 // batch_size, hidden_size)  # note: 8 // batch_size == 4, not batch_size == 2
cells = torch.zeros(num_layers, 8 // batch_size, hidden_size)

sequence = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])

batched = sequence.view((batch_size, -1))  # shape (2, 4)

embedded = embed(batched)  # shape (2, 4, 3)
last_hidden, (hiddens, cells) = lstm(embedded, (hiddens, cells))

The output (last_hidden) shape is (2, 4, 5) and the hidden and cell state shapes are (1, 4, 5), which sounds right.
But I had to use 8 // batch_size when setting up the initial hidden and cell states, while the doc says h_0 should be of shape (1, batch_size, H_out). My batch_size is 2, not 4.

What am I missing?

Your embedded activation has a shape of [2, 4, 3], corresponding to [batch_size, seq_len, embedding_dim]. However, nn.LSTM expects inputs as [seq_len, batch_size, input_size] by default.
Thus, dim1 with a size of 4 is used as the batch size, which is why your initial states needed a 4 in that position.
You should either permute the embedded activation or use batch_first=True in the nn.LSTM module. Permuting the activation:

hiddens = torch.zeros(num_layers, batch_size, hidden_size)  # (1, 2, 5), as in the docs
cells = torch.zeros(num_layers, batch_size, hidden_size)

sequence = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])

batched = sequence.view((batch_size, -1))  # shape (2, 4)

# permute (batch, seq, emb) -> (seq, batch, emb) for the default batch_first=False
embedded = embed(batched).permute(1, 0, 2).contiguous()
last_hidden, (hiddens, cells) = lstm(embedded, (hiddens, cells))
Alternatively, using batch_first=True:

embed = Embedding(num_chars, embedding_dim)
lstm = LSTM(input_size=embedding_dim, hidden_size=hidden_size, batch_first=True)

hiddens = torch.zeros(num_layers, batch_size, hidden_size)
cells = torch.zeros(num_layers, batch_size, hidden_size)

sequence = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])

batched = sequence.view((batch_size, -1))  # shape (2, 4)

# with batch_first=True, the (batch, seq, emb) layout is consumed directly
embedded = embed(batched)
last_hidden, (hiddens, cells) = lstm(embedded, (hiddens, cells))
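In both versions last_hidden now corresponds to a batch size of 2: the permuted variant returns it as (4, 2, 5) and the batch_first variant as (2, 4, 5), while hiddens and cells come out as (1, 2, 5), matching the docs.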

Actually, the code runs perfectly well; I’m just wondering whether the doc is correct about N. Or maybe I got lucky that it runs even though my setup is wrong? But I don’t see an error on my side.

For example, with no error raised, I see last_hidden come out with shape (2, 4, 5).

Oh, and one more question: when using batches that way with an LSTM, is it simply equivalent to running multiple separate LSTMs in parallel? The hiddens and cells take on one extra dimension, which seems to suggest there are just batch_size LSTMs running independently in parallel.

Basically, yes, it is equivalent to multiple instances running in parallel, although no extra copies of the model are made when performing the forward pass with a batch size > 1.
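You can verify this yourself. Here's a minimal sketch (a standalone toy LSTM, not your model) showing that each sequence in the batch is processed independently:

import torch
from torch.nn import LSTM

torch.manual_seed(0)
lstm = LSTM(input_size=3, hidden_size=5)

x = torch.randn(4, 2, 3)  # (seq_len=4, batch_size=2, input_size=3)
out_batched, _ = lstm(x)

# run each sequence on its own (batch size 1) and compare
out_a, _ = lstm(x[:, 0:1, :])
out_b, _ = lstm(x[:, 1:2, :])
print(torch.allclose(out_batched[:, 0:1, :], out_a, atol=1e-6))  # True
print(torch.allclose(out_batched[:, 1:2, :], out_b, atol=1e-6))  # True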

Generally speaking, a neural network will train more efficiently and overfit less by accumulating the gradients over several examples before backpropagation.

Additionally, the larger the batch size, the higher the learning rate you can typically use without getting exploding gradients.

And, lastly, running batches lets you take advantage of hardware parallelism, which is especially beneficial if you’re using a GPU.

Yes, the docs are correct, and you got lucky: by changing the shapes you ended up executing a wrong setup that happens to run. Check my code snippets and print all the shapes to understand where the issue is.
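For example, reusing the variables from your first snippet (with the freshly initialized (1, 4, 5) states), printing the shapes shows where N comes from:

embedded = embed(batched)
print(embedded.shape)  # torch.Size([2, 4, 3]) -> read as (seq_len=2, batch=4, input_size=3)
out, (h, c) = lstm(embedded, (hiddens, cells))
print(out.shape, h.shape, c.shape)
# torch.Size([2, 4, 5]) torch.Size([1, 4, 5]) torch.Size([1, 4, 5])
# so N is silently 4 here, which is why your (1, 4, 5) initial states were required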

Got it, makes sense. Yeah, I just got lucky that the mixed-up dimensions still happened to work. Thanks!

I’m confused: am I training 1 LSTM or 4 if my batch_size is 4? Are the internal weight and gate matrices shared between the 4 sequences in the batch, or separate?

“Shared”, although I use that word very loosely here. What really happens is no different from any other network with a batch size > 1.

The loss function determines what happens with autograd along the batch dim. For instance, L1Loss has a reduction parameter whose default is “mean”. With that configuration the losses are averaged across the batch dim, so the gradients are effectively averaged as well. See here:

https://pytorch.org/docs/stable/generated/torch.nn.L1Loss.html
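As a small illustration (toy tensors, not your model), you can see how the reduction affects the gradients:

import torch
from torch.nn import L1Loss

torch.manual_seed(0)
pred = torch.randn(4, 5, requires_grad=True)  # batch of 4 samples, 5 values each
target = torch.randn(4, 5)

L1Loss(reduction="mean")(pred, target).backward()
print(pred.grad.abs().unique())  # tensor([0.0500]) -> each element's gradient is scaled by 1/20

pred.grad = None
L1Loss(reduction="sum")(pred, target).backward()
print(pred.grad.abs().unique())  # tensor([1.]) -> no averaging with reduction="sum"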

You calculate the loss prior to backpropagation. Nothing about the parameters changes until backprop, i.e. loss.backward() followed by optimizer.step().

Yeah, ok, makes sense, but when backprop takes place with .backward(), is it going to update as many internal weight matrices in the LSTM as there are sequences in the batch? Or just one, no matter what the batch size is?

Just one. There is only one copy of the parameters. The gradients get averaged over the batch and sequence dims for each batch of data given to the model.
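A quick way to convince yourself (again a standalone toy LSTM): the parameter tensors exist exactly once, and their shapes don't depend on the batch size:

import torch
from torch.nn import LSTM

lstm = LSTM(input_size=3, hidden_size=5)
for name, p in lstm.named_parameters():
    print(name, tuple(p.shape))
# weight_ih_l0 (20, 3), weight_hh_l0 (20, 5), bias_ih_l0 (20,), bias_hh_l0 (20,)

# the same four tensors are used for any batch size:
out1, _ = lstm(torch.randn(6, 1, 3))    # batch size 1
out32, _ = lstm(torch.randn(6, 32, 3))  # batch size 32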