Say that we use the LSTM example from the PyTorch docs:
>>> import torch
>>> import torch.nn as nn
>>> rnn = nn.LSTM(10, 20, 2)
>>> x = torch.randn(5, 3, 10)
>>> output, (hn, cn) = rnn(x)
Suppose we also feed the last hidden state to an MLP to get a prediction.
How can we feed the outputs from the LSTM to nn.MultiheadAttention?
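
For reference, here is a minimal sketch of one way this could be wired up, not a definitive answer. It assumes self-attention over the LSTM outputs (query, key and value are all `output`), sets `embed_dim` to the LSTM hidden size (20), and picks `num_heads=4` only because it divides 20. The `mlp` head and its layer sizes are hypothetical, added just to show where a prediction could come from.

```python
import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, 2)       # input_size=10, hidden_size=20, num_layers=2
x = torch.randn(5, 3, 10)      # (seq_len=5, batch=3, input_size=10)
output, (hn, cn) = rnn(x)      # output: (5, 3, 20), top-layer state at every time step

# Assumption: embed_dim equals the LSTM hidden size and is divisible by num_heads.
# Both modules default to batch_first=False, so the (seq_len, batch, features)
# layout of `output` can be passed through unchanged.
attn = nn.MultiheadAttention(embed_dim=20, num_heads=4)

# Self-attention: every time step attends to every other time step.
attn_output, attn_weights = attn(output, output, output)
# attn_output: (5, 3, 20); attn_weights: (3, 5, 5), averaged over the heads

# Hypothetical MLP head applied to the attended representation of the last step.
mlp = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 1))
prediction = mlp(attn_output[-1])   # (3, 1)
```

Another option, under the same assumptions, would be to use the last hidden state as the only query, e.g. `attn(hn[-1].unsqueeze(0), output, output)`, which returns a single attended vector per sequence with shape (1, 3, 20).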