In my neural network model, I represent each word with a 256-dimensional embedding vector. For a sentence with 8 words, I get an 8x256 matrix. I want to give this matrix to an LSTM as input so that the LSTM processes it one token at a time, and then I am going to use its final hidden state.

According to the PyTorch documentation, the input should have the shape (seq_len, batch, input_size). In my case, seq_len is 8, batch is 1 and input_size is 256. My question is: what is the correct way to convert my input to the desired shape? I don't want to end up with a matrix whose values are misplaced. I am quite new to PyTorch and row-major layouts, so I wanted to ask here. I do it as follows; is it correct?

```python
import torch
x = torch.rand(8, 256)  # 8 word embeddings, 256 dims each
lstm_input = torch.reshape(x, (8, 1, 256))
```

Is this the correct way to do it, or should I do something different, like taking the transpose first?
In addition to my specific question, I would be really grateful if someone could give me some general rules to keep in mind when changing the shape of any matrix.

reshape() preserves the ordering of the elements in the tensor. So if x has shape (seq_len, input_size), x.reshape(seq_len, 1, input_size) keeps the correct ordering. As you described, your seq_len is 8 and your input_size is 256, so your code is correct and no transpose is needed.
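A minimal sketch to check this yourself and then feed the result to an LSTM. It assumes the default nn.LSTM layout (seq_len, batch, input_size) and an arbitrary hidden_size of 128, which is not from the question. It also shows that x.unsqueeze(1) gives the same tensor as the reshape and states the intent (add a size-1 batch dimension) more directly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

seq_len, input_size, hidden_size = 8, 256, 128  # hidden_size is an arbitrary choice

# One sentence: 8 word embeddings of size 256, shape (seq_len, input_size)
x = torch.rand(seq_len, input_size)

# Insert a batch dimension of size 1 in the middle: (seq_len, 1, input_size)
lstm_input = x.reshape(seq_len, 1, input_size)

# unsqueeze(1) adds the same size-1 dimension and gives an identical tensor
assert torch.equal(lstm_input, x.unsqueeze(1))

# Time step t still holds the embedding of word t
for t in range(seq_len):
    assert torch.equal(lstm_input[t, 0], x[t])

# Default nn.LSTM expects (seq_len, batch, input_size)
lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size)
output, (h_n, c_n) = lstm(lstm_input)

print(output.shape)  # torch.Size([8, 1, 128])
print(h_n.shape)     # torch.Size([1, 1, 128])

# For a single-layer, unidirectional LSTM, the final hidden state equals
# the output at the last time step
final_hidden = h_n[-1]  # shape (batch, hidden_size) = (1, 128)
assert torch.allclose(final_hidden, output[-1])
```

h_n has shape (num_layers * num_directions, batch, hidden_size), so h_n[-1] picks the last layer's final hidden state for each sequence in the batch.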

If I have a tensor with shape (bs, seq_len, input_size) and want to reshape it to (bs, seq_len, 1, input_size), can I still use the same technique? That is, are the following lines still valid?

```python
import torch
bs, seq_len, input_size = 5, 20, 128
x = torch.rand(bs, seq_len, input_size)
torch.reshape(x, (x.shape[0], x.shape[1], 1, x.shape[2]))
```

Basically, I wonder: in which cases should I be more careful not to mix up the values of the original tensor?

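Yes. For x of shape (m, n, p), reshape preserves the row-major ordering of the elements x[0][0][0], x[0][0][1], ..., x[m-1][n-1][p-1]. Inserting a size-1 dimension therefore never mixes up values.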
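A small sketch of that rule, using the shapes from the question. reshape only regroups the flat, row-major sequence of elements, so it is safe when you add or remove size-1 dimensions, or split and merge neighbouring dimensions without reordering them. When you need the axes in a different order, you need permute/transpose instead. The tiny arange example (not from the thread) shows that reshape(3, 2) and a transpose have the same shape but different contents.

```python
import torch

bs, seq_len, input_size = 5, 20, 128
x = torch.rand(bs, seq_len, input_size)

# Insert a size-1 dimension at index 2: (bs, seq_len, 1, input_size)
y = torch.reshape(x, (x.shape[0], x.shape[1], 1, x.shape[2]))

# Same result as unsqueeze(2), which states the intent directly
assert torch.equal(y, x.unsqueeze(2))

# The flat, row-major sequence of elements is unchanged
assert torch.equal(y.flatten(), x.flatten())

# A tiny example makes the ordering visible
a = torch.arange(6).reshape(2, 3)
print(a)
# tensor([[0, 1, 2],
#         [3, 4, 5]])

# reshape regroups the flat sequence 0..5 into rows of 2
print(a.reshape(3, 2))
# tensor([[0, 1],
#         [2, 3],
#         [4, 5]])

# transpose swaps the axes: same shape (3, 2), different contents
print(a.t())
# tensor([[0, 3],
#         [1, 4],
#         [2, 5]])
```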

Reshape is not appropriate if you have an input of shape (bs, seq_len, input_size) and want to pass it to an LSTM that expects (seq_len, bs, input_size). Reshape would give you the right shape but would regroup the flat data instead of swapping the axes. In that case, use permute to swap the first two dimensions: lstm_input = x.permute(1, 0, 2)
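A minimal sketch contrasting the two, assuming the shapes from the earlier question and an arbitrary hidden_size of 64. It also shows one alternative not mentioned above. nn.LSTM accepts batch_first=True, which lets you pass (bs, seq_len, input_size) directly. Note that h_n keeps the (num_layers, batch, hidden_size) layout even then.

```python
import torch
import torch.nn as nn

bs, seq_len, input_size, hidden_size = 5, 20, 128, 64  # hidden_size is arbitrary
x = torch.rand(bs, seq_len, input_size)  # batch-first layout

# Correct: swap the batch and sequence axes -> (seq_len, bs, input_size)
lstm_input = x.permute(1, 0, 2)
assert torch.equal(lstm_input, x.transpose(0, 1))  # same thing for two axes

# Time step t of sample b is still x[b, t]
assert torch.equal(lstm_input[3, 1], x[1, 3])

# Wrong: reshape gives the right shape but regroups the flat data,
# so lstm_input_bad[3, 1] is actually x[0, 16], not x[1, 3]
lstm_input_bad = x.reshape(seq_len, bs, input_size)
assert lstm_input_bad.shape == lstm_input.shape
assert not torch.equal(lstm_input_bad, lstm_input)

# Alternative: let the LSTM accept batch-first input directly
lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
output, (h_n, c_n) = lstm(x)
print(output.shape)  # torch.Size([5, 20, 64])
print(h_n.shape)     # torch.Size([1, 5, 64]); h_n is not batch-first
```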