I think there is a typo in the LSTMCell example on the API doc page: https://pytorch.org/docs/stable/generated/torch.nn.LSTMCell.html
The current snippet is:
>>> rnn = nn.LSTMCell(10, 20)
>>> input = torch.randn(3, 10)
>>> hx = torch.randn(3, 20)
>>> cx = torch.randn(3, 20)
>>> output = []
>>> for i in range(6):
...     hx, cx = rnn(input[i], (hx, cx))
...     output.append(hx)
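The loop indexes input[i] for 6 time steps, but input only has 3 rows, and each row is a single vector of size 10 rather than a batch of 3. I think the input should instead have a leading time dimension: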
input = torch.randn(6, 3, 10)
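
For reference, here is a minimal sketch of how I would expect the example to run with that change. The variable name `inputs` and the final `torch.stack` call are my own additions for illustration and are not part of the snippet on the doc page.

```python
import torch
import torch.nn as nn

rnn = nn.LSTMCell(10, 20)        # input_size=10, hidden_size=20
inputs = torch.randn(6, 3, 10)   # (time_steps, batch, input_size)
hx = torch.randn(3, 20)          # (batch, hidden_size)
cx = torch.randn(3, 20)

output = []
for i in range(inputs.size(0)):  # iterate over the 6 time steps
    # inputs[i] has shape (3, 10): one batch of 3 vectors per step
    hx, cx = rnn(inputs[i], (hx, cx))
    output.append(hx)

# Assumption: stacking is only added here to make the result easy to inspect
output = torch.stack(output, dim=0)
print(output.shape)
```

With these shapes, each step receives a (batch, input_size) tensor that matches the (batch, hidden_size) states, so the stacked result should have one hidden state per time step.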
Could someone please confirm this?