Regarding LSTMCell example in documentation

I think there is a typo in the LSTMCell example on this API doc page: https://pytorch.org/docs/stable/generated/torch.nn.LSTMCell.html
The current snippet is:

>>> rnn = nn.LSTMCell(10, 20)
>>> input = torch.randn(3, 10)
>>> hx = torch.randn(3, 20)
>>> cx = torch.randn(3, 20)
>>> output = []
>>> for i in range(6):
...     hx, cx = rnn(input[i], (hx, cx))
...     output.append(hx)

I think the input should instead be:

input = torch.randn(6, 3, 10)

Could someone please confirm this?

Yes, that’s right, and it has already been fixed in the master docs. :slight_smile:
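For reference, here is a minimal runnable sketch of the corrected loop. It assumes the input is laid out as (time_steps, batch, input_size) = (6, 3, 10), so that each `input[i]` is one time step of shape (3, 10), matching the (3, 20) hidden and cell states. The loop bound taken from `input.size(0)` and the final `torch.stack` call are my own additions for illustration, not necessarily what the fixed docs contain.

```python
import torch
import torch.nn as nn

# LSTMCell(input_size=10, hidden_size=20) processes ONE time step per call
rnn = nn.LSTMCell(10, 20)

# Assumed layout: (time_steps, batch, input_size) = (6, 3, 10)
input = torch.randn(6, 3, 10)
hx = torch.randn(3, 20)  # (batch, hidden_size)
cx = torch.randn(3, 20)  # (batch, hidden_size)

output = []
for i in range(input.size(0)):  # iterate over the 6 time steps
    # input[i] has shape (3, 10): one time step for the whole batch
    hx, cx = rnn(input[i], (hx, cx))
    output.append(hx)

# Stack the per-step hidden states into (time_steps, batch, hidden_size)
output = torch.stack(output, dim=0)
print(output.shape)  # torch.Size([6, 3, 20])
```

Deriving the loop bound from the first dimension instead of hardcoding 6 means the loop cannot index past the end of the sequence, which is exactly what goes wrong with the original `torch.randn(3, 10)` input.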
