Being more of an NLP person who deals regularly with LSTMs and GRUs – but this is a general issue, I think – I’ve noticed that many people make a fundamental mistake. I’ve seen it in many GitHub projects I’ve tried to reproduce, but also here in the forum (usually something like: “My network runs but does not train/learn properly”, even for arguably simple networks).
When using LSTMs or GRUs, the input and/or output do not have the right shape. For example, the input batch has the shape [batch_size, seq_len, hidden_size], but an LSTM without batch_first=True assumes an input of shape [seq_len, batch_size, hidden_size]. Many people incorrectly use view() or reshape() to fix the shape. While this does fix the shape, it messes up the data and essentially prevents proper training (e.g., the loss does not go down). The correct way here is to use transpose() or permute() to swap the dimensions.
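To make the fix concrete, here is a minimal sketch of the correct handling with an actual nn.LSTM; the sizes are made up purely for illustration:

import torch
import torch.nn as nn

batch_size, seq_len, hidden_size = 3, 5, 8   # toy sizes, for illustration only

# A batch as it typically comes out of a DataLoader: [batch_size, seq_len, hidden_size]
x = torch.randn(batch_size, seq_len, hidden_size)

# batch_first=False is the default, so the layer expects [seq_len, batch_size, hidden_size]
lstm = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size)

# Correct: swap the first two dimensions with transpose() (or permute())
out, (h, c) = lstm(x.transpose(0, 1))
print(out.shape)   # torch.Size([5, 3, 8])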
To illustrate the problem, let’s create an example tensor with a batch_size and a seq_len dimension; I omit the hidden_size dimension to keep things simple:
import torch

batch_size, seq_len = 3, 5
A = torch.zeros((batch_size, seq_len))
A[0, :] = 1
A[1, :] = 2
A[2, :] = 3
This gives a tensor of shape [batch_size, seq_len] that looks like this:
tensor([[1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3.]])
Now let’s say you have defined an nn.LSTM layer with batch_first=False. In this case, you need to swap the two dimensions of A to make it a valid input for the LSTM layer (again, the hidden_size dimension is omitted here). You can do this with either of the following commands:
A1 = A.transpose(1, 0)
A2 = A.permute(1, 0)
If you print either A1 or A2, you get the correct result:
tensor([[1., 2., 3.],
        [1., 2., 3.],
        [1., 2., 3.],
        [1., 2., 3.],
        [1., 2., 3.]])
However, if you use view() or reshape() like
A3 = A.view(seq_len, -1)
A4 = A.reshape(seq_len, -1)
you get the wrong result:
tensor([[1., 1., 1.],
        [1., 1., 2.],
        [2., 2., 2.],
        [2., 3., 3.],
        [3., 3., 3.]])
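The reason is that view() and reshape() simply keep the elements in their row-major memory order and re-chunk them into the new shape; nothing is actually moved between sequences. You can see this by flattening A:

print(A.flatten())
# tensor([1., 1., 1., 1., 1., 2., 2., 2., 2., 2., 3., 3., 3., 3., 3.])
# view(seq_len, -1) just cuts this flat sequence into rows of length 3,
# which is why the rows above mix values from different sequences.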
Note that A1, A2, A3, and A4 all have the shape [5, 3], i.e., [seq_len, batch_size]. So with respect to the shape, all results are correct, and the LSTM layer will accept each of them without throwing any error.
However, A3 and A4 have messed up the batch semantically. The network might still learn something – after all, there’s usually a pattern to be found in everything – but it will learn the wrong pattern.
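You can verify this directly with the tensors from above: after transpose() or permute(), each column of the result is still one of the original sequences, while after view() or reshape() it is not:

# Each column of A1 is still one of the original sequences ...
print(torch.equal(A1[:, 0], A[0]))   # True
# ... but the columns of A3 mix values from different sequences
print(torch.equal(A3[:, 0], A[0]))   # False
# transpose() and permute() agree with each other; view() does not
print(torch.equal(A1, A2))           # True
print(torch.equal(A1, A3))           # False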
view() and reshape() obviously have their purpose, for example, to flatten tensors. But for swapping dimensions they are generally the wrong methods. I hope this avoids some head-scratching, particularly for beginners. And I can point to this post when I see the next question that most likely suffers from this issue :).
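As a closing example of such a legitimate use (again with made-up toy sizes): merging the seq_len and batch_size dimensions of an output before a linear layer is a case where reshape() is exactly right, because the dimensions are only merged, not swapped:

seq_len, batch_size, hidden_size = 5, 3, 8             # toy sizes, for illustration only
out = torch.randn(seq_len, batch_size, hidden_size)    # e.g., the output of an LSTM
flat = out.reshape(-1, hidden_size)                    # merge seq_len and batch_size
print(flat.shape)                                      # torch.Size([15, 8])

Each hidden vector stays intact here; only the first two dimensions are collapsed into one.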