I think I’ve mostly wrapped my head around how the vision transformer (ViT) works, but I’m a bit confused about how it makes use of the full sequence of embeddings.
According to the paper, "An Image is Worth 16x16 Words", a class token is prepended to the sequence of patch embeddings, and for classification that class token is extracted and passed through an MLP block to produce the final classification. This part makes sense to me. However, what I don’t understand is: if we are stripping off the
[class] token to pass through the MLP, how is the rest of the information in the embedding sequence ever seen or used when training with backpropagation?
x = PatchAndEmbed(image)  # breaks the image into patches and creates patch embeddings
x = torch.cat((class_token, x), dim=1)  # places a pre-declared class token parameter at position 0 of the sequence
x = x + PositionalEmbedding()  # add sinusoidal positional embedding
x = TransformerEncoder(x)  # passes the batch of embeddings through the transformer encoder
class_token = x[:, 0]  # extracts the class token from the encoder output
output = mlp(class_token)  # passes the class token through the MLP block to get categorical logits
I’ve created my model, and it works, but I’m not sure why. I would expect that backpropagation only ever passes back through the class token part of the embedding. How is the rest of that information ever used/trained?
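To make the question concrete, here is a minimal sketch of the kind of check I have in mind (the layer sizes and the use of a single nn.MultiheadAttention layer are just illustrative, not my actual model): the loss is computed only from the class token's output, and I inspect whether the patch tokens still receive any gradient.

import torch
import torch.nn as nn

# Toy setup: one self-attention layer standing in for the encoder.
embed_dim, num_patches, batch = 8, 4, 2
attn = nn.MultiheadAttention(embed_dim, num_heads=2, batch_first=True)

class_token = torch.zeros(batch, 1, embed_dim, requires_grad=True)
patches = torch.randn(batch, num_patches, embed_dim, requires_grad=True)
x = torch.cat((class_token, patches), dim=1)  # (batch, 1 + num_patches, embed_dim)

out, _ = attn(x, x, x)       # self-attention mixes every token with every other token
loss = out[:, 0].sum()       # loss uses only the class token slot
loss.backward()

print(patches.grad.abs().sum())  # non-zero, so the patch tokens do get gradients

Running something like this, patches.grad comes out non-zero, which seems to show the patch embeddings are being trained even though only the class token feeds the loss, and that is exactly the mechanism I'm trying to understand.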