Being more of an NLP person who deals regularly with LSTMs and GRUs – but this is a general issue, I think – I’ve noticed that many people make a fundamental mistake. I’ve seen it in many GitHub projects I’ve tried to reproduce, but also here in the forum (usually something like: “My network runs but does not train/learn properly”, even for arguably simple networks).
When using LSTMs or GRUs, the input and/or output often do not have the right shape. For example, the input batch has the shape `[batch_size, seq_len, hidden_size]`, but an LSTM without `batch_first=True` expects an input of shape `[seq_len, batch_size, hidden_size]`. Many people incorrectly use `reshape()` to fix the shape. While it does fix the shape, it messes up the data and essentially prohibits proper training (e.g., the loss does not go down). The correct way here is to use `permute()` to swap dimensions.
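As a minimal sketch of what this looks like in practice (the sizes here are made up for illustration):

```python
import torch
import torch.nn as nn

batch_size, seq_len, hidden_size = 3, 5, 8  # example sizes

# Default is batch_first=False, i.e., the LSTM expects [seq_len, batch_size, hidden_size]
lstm = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size)

x = torch.randn(batch_size, seq_len, hidden_size)  # batch-first data
x = x.permute(1, 0, 2)                             # -> [seq_len, batch_size, hidden_size]

output, (h_n, c_n) = lstm(x)
print(output.shape)  # torch.Size([5, 3, 8])
```

Using `x.reshape(seq_len, batch_size, hidden_size)` here would produce the same shape but scramble the sequences, which is exactly the trap described above.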
To illustrate the problem, let’s create an example tensor with a `seq_len` dimension; I omit the `hidden_size` dimension to keep things simple:
```python
import torch

batch_size, seq_len = 3, 5
A = torch.zeros((batch_size, seq_len))
A[0, :] = 1
A[1, :] = 2
A[2, :] = 3
```
This gives a tensor with shape `[batch_size, seq_len]` that looks like this:
```
tensor([[1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3.]])
```
Now let’s say you have defined an `nn.LSTM` layer with `batch_first=False`. In this case, you need to swap the two dimensions of `A` to make it a valid input for the LSTM layer (again, the `hidden_size` dimension is omitted here). You can do this with the following commands:
```python
A1 = A.transpose(1, 0)
A2 = A.permute(1, 0)
```
If you print either `A1` or `A2`, you get the correct result:
```
tensor([[1., 2., 3.],
        [1., 2., 3.],
        [1., 2., 3.],
        [1., 2., 3.],
        [1., 2., 3.]])
```
However, if you use
```python
A3 = A.view(seq_len, -1)
A4 = A.reshape(seq_len, -1)
```
you get the wrong result:
```
tensor([[1., 1., 1.],
        [1., 1., 2.],
        [2., 2., 2.],
        [2., 3., 3.],
        [3., 3., 3.]])
```
Note that for `A3` and `A4` the shape is also `[5, 3]`. So with respect to the shape, i.e., `[seq_len, batch_size]`, all results are correct, and the LSTM layer will accept each of them without throwing any error. However, `A3` and `A4` have messed up the batch semantically. The network might still learn something – after all, there’s usually a pattern to be found in everything – but it will learn the wrong pattern.
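You can make the difference explicit with a quick sanity check (a sketch using the same example tensor as above):

```python
import torch

batch_size, seq_len = 3, 5
A = torch.zeros((batch_size, seq_len))
A[0, :] = 1
A[1, :] = 2
A[2, :] = 3

A2 = A.permute(1, 0)         # correct: swaps the two dimensions
A4 = A.reshape(seq_len, -1)  # wrong: reinterprets the flat memory layout

# Every row (time step) of A2 contains exactly one value per batch item...
print(A2[1])  # tensor([1., 2., 3.])
# ...while a row of A4 can mix values from different batch items.
print(A4[1])  # tensor([1., 1., 2.])
```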
`view()` and `reshape()` obviously have their purpose, for example, to flatten tensors. But for swapping dimensions they are generally the wrong methods. I hope this avoids some head-scratching, particularly for beginners. And I can point to this post when I see the next question that most likely suffers from this issue :).
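For completeness, here is a sketch of a legitimate use: flattening everything except the batch dimension, e.g., before a linear layer (sizes made up for illustration):

```python
import torch

x = torch.randn(4, 3, 5)  # e.g., [batch_size, seq_len, hidden_size]
flat = x.reshape(4, -1)   # flatten all non-batch dimensions; intended here
print(flat.shape)         # torch.Size([4, 15])
```

The difference is that flattening deliberately merges dimensions, whereas the problem above was using `reshape()` to *swap* them.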