If you remove batch_first=True, it’s of course batch_first=False by default. In this case you would need to change out = self.fc(out[:, -1, :]) to out = self.fc(out[-1]).
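A minimal sketch of why both indexing forms pick the same last time step, assuming a single-layer, unidirectional LSTM followed by a linear layer. The sizes below are made up for illustration. Both LSTMs share the same weights, so the only difference is the tensor layout: with batch_first=True the sequence axis is dim 1, and with the default batch_first=False it is dim 0.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration
batch_size, seq_len, input_size, hidden_size, num_classes = 4, 7, 5, 16, 3

torch.manual_seed(0)

lstm_bf = nn.LSTM(input_size, hidden_size, batch_first=True)
lstm_sf = nn.LSTM(input_size, hidden_size)  # batch_first=False (the default)
# Copy the weights so both models compute exactly the same function
lstm_sf.load_state_dict(lstm_bf.state_dict())
fc = nn.Linear(hidden_size, num_classes)

x_bf = torch.randn(batch_size, seq_len, input_size)  # (batch, seq, feature)
x_sf = x_bf.transpose(0, 1)                          # (seq, batch, feature)

out_bf, _ = lstm_bf(x_bf)         # (batch, seq, hidden)
out_sf, (h_n, _) = lstm_sf(x_sf)  # (seq, batch, hidden)

# Last time step: dim 1 with batch_first=True, dim 0 with batch_first=False
logits_bf = fc(out_bf[:, -1, :])
logits_sf = fc(out_sf[-1])

print(logits_bf.shape)                       # torch.Size([4, 3])
print(torch.allclose(logits_bf, logits_sf))  # True
```

For a single-layer, unidirectional LSTM, out_sf[-1] is also the same as h_n[-1], so either can be fed to the linear layer.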
I don’t know what you’re trying to learn or what your data looks like, but x = x.reshape(-1, INPUT_SIZE, MAX_STRING_SIZE) looks a bit suspicious. Firstly, why do you need to infer the batch size (I assume batch_first=True and the first dimension is for the batch)? And secondly, input_size is expected to be in the 3rd (last) dimension of the input tensors for an LSTM/GRU.
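Here is a small sketch of the shape check. The values for INPUT_SIZE and MAX_STRING_SIZE are hypothetical stand-ins for your constants, and I’m assuming MAX_STRING_SIZE is the sequence length and INPUT_SIZE the number of features per step. With batch_first=True the LSTM expects (batch, seq_len, input_size). A tensor laid out as (batch, INPUT_SIZE, MAX_STRING_SIZE) has input_size in dim 1, and nn.LSTM rejects it because it compares input.size(-1) against input_size.

```python
import torch
import torch.nn as nn

# Hypothetical values standing in for the INPUT_SIZE / MAX_STRING_SIZE constants
INPUT_SIZE = 8        # features per time step
MAX_STRING_SIZE = 20  # sequence length
BATCH_SIZE = 2

lstm = nn.LSTM(input_size=INPUT_SIZE, hidden_size=32, batch_first=True)

# Expected layout for batch_first=True: (batch, seq_len, input_size)
x_ok = torch.randn(BATCH_SIZE, MAX_STRING_SIZE, INPUT_SIZE)
out, _ = lstm(x_ok)
print(out.shape)  # torch.Size([2, 20, 32])

# Layout produced by reshape(-1, INPUT_SIZE, MAX_STRING_SIZE): input_size sits in dim 1
x_bad = torch.randn(BATCH_SIZE, INPUT_SIZE, MAX_STRING_SIZE)
try:
    lstm(x_bad)
    mismatch_detected = False
except RuntimeError as err:
    # nn.LSTM raises because the last dimension (20) != input_size (8)
    mismatch_detected = True
    print("RuntimeError:", err)
```

Note that the error only shows up because 8 != 20 here. If the two sizes happened to be equal, the wrong layout would run silently and produce garbage.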
I’m not sure about your reshape() in the return line either. If you’re not careful, reshape() and view() can quickly mess up your data.
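To illustrate how this happens, here is a toy tensor (values made up) with one sample, 2 features and 3 time steps. permute() actually swaps the axes, so each time step keeps its own features. reshape() (and view()) only reinterpret the flat memory order, so they produce the same target shape but interleave values from different features.

```python
import torch

# Toy tensor: one sample, 2 features, 3 time steps -> shape (1, 2, 3)
x = torch.arange(6).reshape(1, 2, 3)
# x[0] = [[0, 1, 2],
#         [3, 4, 5]]

# permute swaps axes: row t now holds the 2 features of time step t
x_perm = x.permute(0, 2, 1)  # shape (1, 3, 2)
# reshape keeps the flat order 0..5 and just regroups it
x_resh = x.reshape(1, 3, 2)  # shape (1, 3, 2)

print(x_perm[0].tolist())           # [[0, 3], [1, 4], [2, 5]]
print(x_resh[0].tolist())           # [[0, 1], [2, 3], [4, 5]]
print(torch.equal(x_perm, x_resh))  # False
```

Both results have the same shape, so nothing raises an error. That is why a wrong reshape() or view() is easy to miss. If you need to move an axis, use permute() or transpose().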