For beginners: Do not use view() or reshape() to swap dimensions of tensors!

Being more of an NLP person who deals regularly with LSTMs and GRUs – though this is a general issue, I think – I’ve noticed that many people make a fundamental mistake. I’ve seen it in many GitHub projects I’ve tried to reproduce, but also here in the forum (usually phrased as: “My network runs but does not train/learn properly”, even for arguably simple networks).

When using LSTMs or GRUs, the input and/or output often do not have the right shape. For example, the input batch has shape [batch_size, seq_len, hidden_size], but an LSTM without batch_first=True expects an input of shape [seq_len, batch_size, hidden_size]. Many people incorrectly use view() or reshape() to fix the shape. While this does fix the shape, it scrambles the data and essentially prevents proper training (e.g., the loss does not go down).

The correct way here is to use transpose() or permute() to swap dimensions.

To illustrate the problem, let’s create an example tensor with a batch_size and a seq_len dimension; I omit the hidden_size dimension to keep things simple.

import torch

batch_size, seq_len = 3, 5
A = torch.zeros((batch_size, seq_len))
A[0, :] = 1  # first sequence in the batch contains only 1s
A[1, :] = 2  # second sequence contains only 2s
A[2, :] = 3  # third sequence contains only 3s

This gives a tensor of shape [batch_size, seq_len] that looks like this:

tensor([[1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3.]])

Now let’s say you have defined an nn.LSTM layer with batch_first=False. In this case, you need to swap the two dimensions of A to make it a valid input for the LSTM layer (again, the hidden_size dimension is omitted here). You can do this with either of the following commands:

A1 = A.transpose(1, 0)
A2 = A.permute(1, 0)

If you print either A1 or A2, you get the correct result:

tensor([[1., 2., 3.],
        [1., 2., 3.],
        [1., 2., 3.],
        [1., 2., 3.],
        [1., 2., 3.]])

However, if you use view() or reshape() like

A3 = A.view(seq_len, -1)
A4 = A.reshape(seq_len, -1)

you get the wrong result:

tensor([[1., 1., 1.],
        [1., 1., 2.],
        [2., 2., 2.],
        [2., 3., 3.],
        [3., 3., 3.]])

Note that A1, A2, A3, and A4 all have the shape [5, 3], i.e., [seq_len, batch_size]. So with respect to the shape, all results are correct, and the LSTM layer will accept each of them without throwing an error.

However, A3 and A4 have messed up the batch semantically. The network might still learn something – after all, there’s usually a pattern to be found in everything – but it will learn the wrong pattern.
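If you want to double-check this yourself, here is the toy example as one self-contained snippet; all four results have the same shape, but only the transposed/permuted tensors keep the data intact:

import torch

batch_size, seq_len = 3, 5
A = torch.zeros((batch_size, seq_len))
A[0, :] = 1
A[1, :] = 2
A[2, :] = 3

A1 = A.transpose(1, 0)       # swaps the dimensions and the data consistently
A2 = A.permute(1, 0)         # same result as transpose here
A3 = A.view(seq_len, -1)     # only reinterprets the flat memory
A4 = A.reshape(seq_len, -1)  # same as view here

print(A1.shape, A3.shape)    # both torch.Size([5, 3])
print(torch.equal(A1, A2))   # True  -- transpose and permute agree
print(torch.equal(A1, A3))   # False -- view has scrambled the batch
print(torch.equal(A3, A4))   # True  -- reshape makes the same mistake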

view() and reshape() obviously have their purpose, for example, to flatten tensors. But for swapping dimensions they are generally the wrong methods. I hope this avoids some head-scratching, particularly for beginners. And I can point to this post when I see the next question that most likely suffers from this issue :).
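To tie this back to an actual layer, here is a minimal sketch (the sizes are made up for illustration) of feeding such a batch into an nn.LSTM created with the default batch_first=False:

import torch
import torch.nn as nn

batch_size, seq_len, input_size, hidden_size = 3, 5, 7, 16

# the batch arrives as (batch_size, seq_len, input_size)
x = torch.randn(batch_size, seq_len, input_size)

lstm = nn.LSTM(input_size, hidden_size)  # batch_first=False by default

# correct: swap the first two dimensions before feeding the LSTM
out, (h_n, c_n) = lstm(x.transpose(0, 1))  # input is (seq_len, batch_size, input_size)
print(out.shape)                           # torch.Size([5, 3, 16])

# wrong: same shape, but steps from different batch elements get mixed up
# out, _ = lstm(x.view(seq_len, batch_size, input_size))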


Thanks for the answer! This seems really useful.

I am curious, what would be use cases for .view then? (perhaps making a list would be useful)

The only places I’ve ever seen it used without causing issues are when flattening (as you mentioned) and when doing a squeeze (despite a squeeze function already existing), e.g.

batch_size = 3
x = torch.randn(batch_size, 4, 5, 6)
x.view(batch_size, -1)  # flat!

or (I saw it in the ImageNet code for computing top-k accuracy, I believe)

x.view(1,-1)

It just seems that .view is weird since it only arranges the tensor into the requested shape by starting from the beginning of memory and enumerating the elements “in order” (ref: Tensor.view is misleading - #2 by rasbt).


So x.view(batch_size, -1) is an example of the broader pattern of combining dimensions. There may be other applications, for example when you want a matrix-multiplication-like operation where the three matrix “axes” you have in vanilla matrix multiplication have different shapes. Fun fact: the very versatile einsum reduces everything to batch matrix multiplications, and it does so using permute and view/reshape.
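For example (a small sketch with made-up sizes), you can merge the batch and sequence dimensions, do one plain matrix multiplication, and split them again; einsum performs essentially the same reshuffling internally:

import torch

B, S, D, K = 4, 10, 32, 8
x = torch.randn(B, S, D)
W = torch.randn(D, K)

# combine batch and sequence dims, multiply, then split them again
y = (x.view(B * S, D) @ W).view(B, S, K)

# the same computation written with einsum
y_ein = torch.einsum('bsd,dk->bsk', x, W)
print(torch.allclose(y, y_ein))  # True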

The other use, its dual if you want, is splitting dimensions. In addition to this happening as the backward of combining dimensions, it can also be useful for things like very simple downscaling (I sometimes use x.view(batch_size, c, h // 2, 2, w // 2, 2) to then do something with the dimensions of size 2 – if you take the max over both, you get a max pool, but you could also do a logsumexp-pool or some such – your imagination is the limit!).
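A rough sketch of that downscaling trick (shapes invented for illustration):

import torch

batch_size, c, h, w = 2, 3, 8, 8
x = torch.randn(batch_size, c, h, w)

# split each spatial dimension into (half the size, 2)
blocks = x.view(batch_size, c, h // 2, 2, w // 2, 2)

# taking the max over both size-2 dimensions gives a 2x2 max pool ...
pooled_max = blocks.amax(dim=(3, 5))      # (batch_size, c, h // 2, w // 2)

# ... but any reduction works, e.g. a logsumexp pool
pooled_lse = blocks.logsumexp(dim=(3, 5))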

Best regards

Thomas


Thanks for the elaboration on this matter. I had been struggling with this issue for 6 months!

There is still another use case left: what if your training data is flat in the first place, i.e. (batch_size, data)?
This scenario could come from a visual encoder whose result is fed into an LSTM and therefore needs to be shaped accordingly.

A reshape to (num_sequences, sequence_length, data) is just fine, but the other way around, to a shape of (sequence_length, num_sequences, data), causes the troubles that you describe.

So what would be a good way to reshape the tensor to sequence-length-first in this case?

It depends on what data in your batch exactly refers to. Since you want to push it through an LSTM, I assume that data contains some sequence somewhere. If so, it really depends on the content/structure of data.

Or do you have a sequence of, say, 32 frames of a video clip that you first encode and then push through an LSTM as a sequence of 32 steps? In this case, the batch_size of the visual encoding becomes your seq_len for the LSTM, with batch_size=1. So you don’t need to reshape anything, only “unsqueeze” a batch dimension. Say h is your batch after visual encoding, with shape (32, data). Then you can do:

h = h.unsqueeze(0)

resulting in a shape of (1, 32, data), which works just fine as input for your LSTM (assuming batch_first=True; with the default batch_first=False you would use h.unsqueeze(1) to get (32, 1, data) instead).

The conversion between sequence-first and batch-first is usually a .permute(1, 0, 2) (it could also be written as a transpose, but to me permute better captures the intent).
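For example (sizes invented):

import torch

batch_size, seq_len, feat = 32, 24, 128
x = torch.randn(batch_size, seq_len, feat)  # batch-first

x_seq_first = x.permute(1, 0, 2)            # (seq_len, batch_size, feat)
x_back = x_seq_first.permute(1, 0, 2)       # batch-first again
print(torch.equal(x, x_back))               # True -- no data was scrambled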


My particular use case originates from reinforcement learning with pixel observations. Let’s assume that I sample 16 agent-environment interactions and use a sequence length of 8.

I’d feed a tensor of shape (16, 3, 84, 84) to a stack of convolutional layers. After that, the data needs to be reshaped into sequences of length 8. Doing so with the sequence dimension first (reshape(8, 2, -1)) messes up the data. Using the batch size as the first dimension works well for the reshape.
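In code, that could look roughly like the following sketch (the feature size 512 and the variable names are made up, and it assumes the 16 interactions are stored sequence by sequence):

import torch

num_interactions, seq_len = 16, 8
num_seqs = num_interactions // seq_len   # 2 sequences of length 8

# pretend output of the conv stack: one feature vector per interaction
feats = torch.randn(num_interactions, 512)

# keep the batch dimension first when splitting into sequences ...
h = feats.view(num_seqs, seq_len, -1)    # (2, 8, 512)
# ... and permute afterwards if the LSTM expects sequence-first input
h_seq_first = h.permute(1, 0, 2)         # (8, 2, 512)

# wrong: feats.view(seq_len, num_seqs, -1) mixes steps from different sequences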

Hello, your answer is very useful for me as a beginner in AI. But in CV we often need to flatten a matrix [M, C, H, W] into a row vector [M, C*H*W]. How can I flatten this matrix without an in-place operation? It has bothered me for a long time, and I’d be grateful if you could reply.

Hi @vdw, thank you for your answer and for referring to this post in my post. I was wondering, if the data is a 128×128 matrix, will this approach still work? Or do I need to flatten the matrix? Thank you so much.

I’m not sure if I fully understand your questions.

The issue with reshaping/transposing a matrix is not an issue of the shape of the matrix or tensor, but of how the matrix or tensor is actually organized in memory. For example, a 2d matrix is not stored as a 2d grid in memory but as an array of length m*n (assuming m and n are the dimensions of the matrix). Which element in the array represents a specific cell of the matrix depends on whether the matrix is stored in row-major or column-major order.
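You can inspect this memory layout directly in PyTorch; a small sketch:

import torch

A = torch.arange(6).view(2, 3)  # contiguous: rows lie one after another in memory
print(A.stride())               # (3, 1) -- move 3 elements to go one row down

At = A.transpose(0, 1)          # same memory, only the strides change
print(At.stride())              # (1, 3)
print(At.is_contiguous())       # False

# view() requires compatible (contiguous) memory, so At.view(-1) raises an error,
# while reshape() silently copies the data:
print(At.reshape(-1))           # tensor([0, 3, 1, 4, 2, 5])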

In short, when all you want to do is to transpose (i.e., “rotate” a tensor by swapping dimensions), you should never use reshape() or view(). Under certain conditions it might work, but not generally.

However, I’m not sure what your task is and what you want to do with your 128x128 matrix.

Hi @vdw, referring to this question: say my input sequence is 24 images and my input is organised as 32×24×128×128, where 32 is the batch size, 24 is the sequence length, and 128×128 is an image. I want to use an LSTM to encode the temporal dependence in the data. How should I format the LSTM input?

Ok, I got you now. Yes, if you have something like (batch_size, seq_len, width, height), you need to transform it into (batch_size, seq_len, width*height) to serve as valid input for the nn.LSTM. This means you have to flatten the image, i.e., flatten the last two dimensions. The easiest way should be:

h = h.view(batch_size, seq_len, -1)

But again, this assumes that h.shape was initially (batch_size, seq_len, width, height).

If for some reason h starts out as (seq_len, batch_size, width, height), you need the additional step:

h = h.transpose(1, 0)  # or: h = h.permute(1, 0, 2, 3)
h = h.view(batch_size, seq_len, -1)

Thank you so much @vdw