When the input becomes 0 (that's what padding with value 0 does), the sum of the previous hidden state and the current input is smaller than with 'normal' inputs. Therefore, the first sigmoid's output will be lower, and the cell state will forget more. However, that might not even be the biggest issue. Looking at the input gate, we can see that the zero-centred tanh multiplied with the sigmoid's output adds less to the cell state than normal. The same goes for the output gate, since its sigmoid outputs lower values than 'normal'. Therefore, the hidden state also gets closer and closer to zero, and the cell state with it. Furthermore, all these operations involve weights, which can reinforce this convergence.
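To see this effect concretely, here is a minimal sketch (the toy sizes and randomly initialized weights are my assumption, not any trained model): an LSTMCell is fed a run of zero "padding" inputs, and the hidden and cell state norms are printed at each step.

import torch
import torch.nn as nn

# Feed an LSTMCell a run of zero "padding" inputs and watch the hidden and
# cell state norms; with the input contribution gone, the gates only see the
# recurrent term, and the states typically shrink toward a small fixed point.
torch.manual_seed(0)
cell = nn.LSTMCell(input_size=4, hidden_size=8)

h = torch.randn(1, 8)    # non-zero hidden state before the padding starts
c = torch.randn(1, 8)    # non-zero cell state before the padding starts
pad = torch.zeros(1, 4)  # a zero-padded input step

for step in range(10):
    h, c = cell(pad, (h, c))
    print(step, h.norm().item(), c.norm().item())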
Since you define your LSTM with the default parameter batch_first=False, the output has the shape (seq_len, batch, hidden_size). That means that out[:, -1, :] gives you the hidden states of all the time steps for the last item in your batch, i.e., a shape of (seq_len, hidden_size).
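As a quick illustration (the sizes here are made up for the example, only input_size and hidden_size match your model):

import torch
import torch.nn as nn

# batch_first=False (the default); illustrative sizes: seq_len=5, batch=3
rnn = nn.LSTM(input_size=14, hidden_size=40)
x = torch.randn(5, 3, 14)   # (seq_len, batch, input_size)
out, _ = rnn(x)

print(out.shape)            # torch.Size([5, 3, 40]) -> (seq_len, batch, hidden_size)
print(out[:, -1, :].shape)  # torch.Size([5, 40])    -> all time steps of the LAST batch item
print(out[-1].shape)        # torch.Size([3, 40])    -> last time step of ALL batch items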
What you want is the last hidden state ('last' w.r.t. the number of time steps) for all items in your batch. You simply change that line to out[-1]. Just to be sure, and since I don't know what your data looks like, can you change your forward() method as follows and post the output of the print statements?
def forward(self, x):
    print(x.shape)
    out, _ = self.rnn(x)
    print(out.shape)
    out = out[:, -1, :]
    print(out.shape)
    out = self.out(out)
    return out
Here is my modified forward() method:

def forward(self, x):
    print('x shape', x.shape)
    out, _ = self.rnn(x)
    print('before reshape', out.shape)
    out = out[:, -1, :]
    print('after reshape', out.shape)
    out = self.out(out)
    print('output shape', out.shape)
    return out
and the result is:
x shape torch.Size([2227, 1, 14])
before reshape torch.Size([2227, 1, 40])
after reshape torch.Size([2227, 40])
output shape torch.Size([2227, 1])
@hzzzm thanks! I assume that 2227 is the sequence length and you only have 1 sequence in your batch, i.e., batch_size = 1. After out = out[:, -1, :], the shape is (2227, 40), which would suddenly mean a batch size of 2227.
out = out[-1] will yield a shape of (1, 40), representing (batch_size, hidden_size), which I strongly assume is what you want.
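A quick sanity check with the shapes from your printout (a dummy tensor, just to show the indexing):

import torch

out = torch.randn(2227, 1, 40)  # (seq_len=2227, batch=1, hidden_size=40)
print(out[:, -1, :].shape)      # torch.Size([2227, 40]) -- next layer sees batch_size 2227
print(out[-1].shape)            # torch.Size([1, 40])    -- (batch_size, hidden_size)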
I have removed the line out = out[:, -1, :] from my forward() to keep the batch size, but the result is still the same.
My data: a single sequence of length 2227, where each time step has 14 properties.
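In case it helps, a hedged sketch of how such a sequence would get its batch dimension before going into the LSTM (placeholder data, assuming batch_first=False):

import torch

data = torch.randn(2227, 14)  # placeholder for the real data: (seq_len, num_features)
x = data.unsqueeze(1)         # (seq_len=2227, batch=1, input_size=14)
print(x.shape)                # torch.Size([2227, 1, 14]) -- matches the printed x shape above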