Hey guys, I’m currently studying the Transformer from "Attention Is All You Need" and I’m trying to implement a simple text classifier using the nn.TransformerEncoder module. I have a question:
Considering that I have a batch size of batch_size, a sequence length of seq_len and an embedding size of emb_size, the output of the Transformer encoder has the shape (batch_size, seq_len, emb_size). How can I feed this output to a simple nn.Linear module? Should I concatenate the vectors, take some kind of average, or just pick the vector corresponding to the last word? Thanks!
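A minimal PyTorch sketch of the three options, assuming the encoder is built with batch_first=True (with the default batch_first=False, the encoder takes and returns (seq_len, batch_size, emb_size) instead). The sizes, nhead, num_layers and num_classes are made-up illustration values, and pad_mask is a hypothetical padding mask. The sketch also shows a mean that skips padding positions and a "last word" lookup that picks the last non-padding token, because with right-padded batches out[:, -1] can land on padding.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumed values, not from the question)
batch_size, seq_len, emb_size, num_classes = 4, 10, 32, 3

# batch_first=True -> input and output shape (batch_size, seq_len, emb_size)
layer = nn.TransformerEncoderLayer(d_model=emb_size, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)

x = torch.randn(batch_size, seq_len, emb_size)  # stand-in for embedded tokens
out = encoder(x)                                # (batch_size, seq_len, emb_size)

# Option 1: average over the sequence dimension -> (batch_size, emb_size)
mean_pooled = out.mean(dim=1)

# Option 2: vector at the last position -> (batch_size, emb_size)
last_token = out[:, -1, :]

# Option 3: concatenate (flatten) all positions -> (batch_size, seq_len * emb_size)
# The Linear layer then needs in_features = seq_len * emb_size, so seq_len must be fixed.
flattened = out.reshape(batch_size, -1)

# With padding: hypothetical mask, True marks a padding position
# (same convention as src_key_padding_mask).
pad_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
pad_mask[:, 7:] = True  # pretend the last 3 positions are padding
out_masked = encoder(x, src_key_padding_mask=pad_mask)

# Masked mean: sum only the real tokens, divide by the number of real tokens
keep = (~pad_mask).unsqueeze(-1).float()        # (batch_size, seq_len, 1)
masked_mean = (out_masked * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)

# "Last word" = last non-padding position of each sequence
lengths = (~pad_mask).sum(dim=1)                # real tokens per sequence
last_real = out_masked[torch.arange(batch_size), lengths - 1]

# Any (batch_size, emb_size) summary can go straight into nn.Linear
classifier = nn.Linear(emb_size, num_classes)
logits = classifier(mean_pooled)                # (batch_size, num_classes)
```

Mean pooling (masked, if the batch is padded) works for any sequence length, while flattening ties nn.Linear's input size to a fixed seq_len * emb_size. Since the encoder has no causal mask, every position attends to every other one, so the last position has no special role and picking it is a somewhat arbitrary choice.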