Say that we use the LSTM example from the PyTorch docs:
>>> import torch
>>> import torch.nn as nn
>>> rnn = nn.LSTM(10, 20, 2)
>>> x = torch.randn(5, 3, 10)
>>> output, (hn, cn) = rnn(x)
And that we feed the last hidden state to an MLP to get a prediction.
How can we feed the outputs from the LSTM to nn.MultiheadAttention?
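For concreteness, here is a minimal sketch of one way this might be wired up. It is not a definitive answer. It assumes the default `(seq_len, batch, feature)` layout, an `embed_dim` equal to the LSTM hidden size (20), and `num_heads=4` (any value that divides 20 should work). The MLP layer sizes and the choice of the last hidden state as the query are illustrative assumptions.

```python
import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, 2)        # input_size=10, hidden_size=20, num_layers=2
x = torch.randn(5, 3, 10)       # (seq_len=5, batch=3, input_size=10)
output, (hn, cn) = rnn(x)       # output: (5, 3, 20), hn: (2, 3, 20)

# Assumption: embed_dim matches the LSTM hidden size; num_heads must divide it.
attn = nn.MultiheadAttention(embed_dim=20, num_heads=4)

# Option A: self-attention over all LSTM outputs (query = key = value).
attn_out, attn_weights = attn(output, output, output)
# attn_out: (5, 3, 20); attn_weights: (3, 5, 5), averaged over heads

# Option B: use the top layer's last hidden state as a single query that
# attends over every time step, then pass the pooled context to the MLP.
query = hn[-1].unsqueeze(0)                         # (1, 3, 20)
context, ctx_weights = attn(query, output, output)  # context: (1, 3, 20)

# Hypothetical MLP head; the layer sizes are placeholders.
mlp = nn.Sequential(nn.Linear(20, 16), nn.ReLU(), nn.Linear(16, 1))
pred = mlp(context.squeeze(0))                      # (3, 1)
```

Option A keeps one vector per time step, so it would still need some pooling before the MLP. Option B already reduces the sequence to one vector per batch element.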