I have been trying to implement a toy Vision Transformer from scratch. The paper mentions an additional class embedding that is trained along with the rest of the model. My question is how this can be implemented in PyTorch. As I understand it, nn.Embedding is a lookup table, so if I use it, it will expect me to give it an input index in order to produce a tensor, which does not seem possible during inference. What would be the correct way to implement the learnable class embedding?
Any help would be appreciated!
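
One common way to do this, as far as I know, is to skip nn.Embedding and register the class embedding directly as an nn.Parameter, so it is a plain learnable tensor that needs no input index. Below is a minimal sketch of that idea, not the paper's reference code. The module name ClassTokenPrepender, the argument embed_dim, the input layout (batch, num_patches, embed_dim), and the truncated-normal init with std=0.02 are all assumptions made for illustration.

```python
import torch
import torch.nn as nn


class ClassTokenPrepender(nn.Module):
    """Hypothetical helper: prepends one learnable class token to a patch sequence."""

    def __init__(self, embed_dim: int):
        super().__init__()
        # A single learnable vector, stored with shape (1, 1, embed_dim).
        # nn.Parameter registers it with the module, so the optimizer updates it
        # and it is saved in state_dict like any other weight.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Assumed initialization; any small random init should work for a toy model.
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, patch_embeddings: torch.Tensor) -> torch.Tensor:
        # patch_embeddings is assumed to be (batch, num_patches, embed_dim).
        batch_size = patch_embeddings.shape[0]
        # expand() repeats the same token across the batch without copying memory.
        cls = self.cls_token.expand(batch_size, -1, -1)
        # Result: (batch, num_patches + 1, embed_dim), class token at position 0.
        return torch.cat([cls, patch_embeddings], dim=1)


if __name__ == "__main__":
    layer = ClassTokenPrepender(embed_dim=8)
    patches = torch.randn(4, 16, 8)
    print(layer(patches).shape)
```

The same forward pass runs unchanged at inference time, because the token is a stored weight rather than something looked up from an input. In a full model, position 0 of the encoder output would then typically be fed to the classification head.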