How do I use nn.TransformerEncoder to make a classifier?

Hey guys, I’m currently studying Transformers from Attention Is All You Need, and I’m trying to implement a simple text classifier using the nn.TransformerEncoder module. I have a question:

Considering that I have a batch size of batch_size, a sequence length of seq_len, and an embedding size of emb_size, the output of the transformer encoder has the shape (batch_size, seq_len, emb_size). How can I feed this output into a simple nn.Linear module? Should I concatenate the values, take some kind of average, or just pick the vector corresponding to the last token? Thanks!
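For reference, here is a minimal sketch of my setup (the hyperparameters are just placeholders, and I’m using batch_first=True so the batch dimension comes first):

```python
import torch
import torch.nn as nn

batch_size, seq_len, emb_size = 32, 128, 512

# Toy batch of already-embedded tokens, just to illustrate the shapes.
x = torch.randn(batch_size, seq_len, emb_size)

encoder_layer = nn.TransformerEncoderLayer(d_model=emb_size, nhead=8, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

out = encoder(x)
print(out.shape)  # torch.Size([32, 128, 512]) == (batch_size, seq_len, emb_size)
```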

Have you checked the PyTorch tutorial on Transformer models? You can see that the last layer is a linear layer whose output size is the number of tokens, because it is a language model; in your case the output will be the class of the input text. I would base your work on that guide, as it is really nicely written.
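If I remember correctly, the head in that tutorial is roughly this (a sketch; ntoken is the vocabulary size in the LM case):

```python
import torch
import torch.nn as nn

emb_size, ntoken = 512, 10000                    # ntoken = vocab size in the LM case
encoder_output = torch.randn(32, 128, emb_size)  # (batch_size, seq_len, emb_size)

# The LM head: a linear layer applied independently at every position,
# so the seq_len dimension is kept in the output.
decoder = nn.Linear(emb_size, ntoken)
logits = decoder(encoder_output)
print(logits.shape)  # torch.Size([32, 128, 10000])
```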

Thanks for your reply! I’m following that tutorial, but I always get an output of shape [batch_size, seq_len, output_dim], where output_dim is the output size of the fully connected layer at the end of the transformer encoder. What is the best approach to get predictions of shape [batch_size, output_dim]?

I think you could reshape the tensor by flattening the last two dimensions, i.e. reshape(batch_size, seq_len * output_dim), before you pass it to the linear layer. I do not know if it is the most appropriate way to go; looking at how they deal with this in the Hugging Face transformers library might also help.
If you try this method, please let me know about the results :slight_smile:
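For example (a sketch applied directly to the encoder output; I also show mean pooling over the sequence as an alternative, which is closer to what some Hugging Face models do and does not tie the head to a fixed seq_len):

```python
import torch
import torch.nn as nn

batch_size, seq_len, emb_size, num_classes = 32, 128, 512, 4
encoder_output = torch.randn(batch_size, seq_len, emb_size)

# Option 1: flatten the sequence and classify.
# Note this ties the linear layer to one fixed seq_len.
flat_head = nn.Linear(seq_len * emb_size, num_classes)
logits = flat_head(encoder_output.reshape(batch_size, seq_len * emb_size))
print(logits.shape)  # torch.Size([32, 4])

# Option 2: average over the sequence dimension and classify.
# This works for any sequence length.
pool_head = nn.Linear(emb_size, num_classes)
logits = pool_head(encoder_output.mean(dim=1))
print(logits.shape)  # torch.Size([32, 4])
```

In a real model you would probably also want to mask out padding positions before averaging, so they don’t pull the mean around.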
