You mentioned in your first post that you are not running `_process_input()`. I thought you meant that it was called inside your Encoder or something like that, but you need to run it in order to get the expected shape.
Here is where the patches are flattened.
I also do not see the class token being added. Is there a reason you left the `_process_input` method out?
Edit:
1. In the implementation from your link, the positional embedding is added inside the encoder.
2. You can access the `class_token` from the parent, since it is defined as `self.class_token = nn.Parameter(...)` in the `VisionTransformer` class.
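To make the shape expectations concrete, here is a minimal sketch of the input-processing steps discussed above (patch flattening, class token, positional embedding). It mirrors the structure of torchvision's `VisionTransformer`, but the module name `PatchInput` and the sizes used are my own illustrative assumptions, not the actual implementation:

```python
import torch
import torch.nn as nn

class PatchInput(nn.Module):
    """Hypothetical minimal sketch of ViT-style input processing."""

    def __init__(self, image_size=32, patch_size=8, hidden_dim=64, in_channels=3):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Conv projection turns each non-overlapping patch into one embedding vector.
        self.conv_proj = nn.Conv2d(
            in_channels, hidden_dim, kernel_size=patch_size, stride=patch_size
        )
        # The class token is defined on the parent module (as in VisionTransformer),
        # so submodules can access it from there.
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        num_patches = (image_size // patch_size) ** 2
        # Positional embedding; torchvision adds this inside the encoder.
        self.pos_embedding = nn.Parameter(
            torch.empty(1, num_patches + 1, hidden_dim).normal_(std=0.02)
        )

    def _process_input(self, x):
        n, c, h, w = x.shape
        x = self.conv_proj(x)                  # (n, hidden_dim, h/p, w/p)
        x = x.reshape(n, self.hidden_dim, -1)  # flatten patches -> (n, hidden_dim, n_patches)
        x = x.permute(0, 2, 1)                 # (n, n_patches, hidden_dim)
        return x

    def forward(self, x):
        x = self._process_input(x)                # without this, shapes won't match
        n = x.shape[0]
        cls = self.class_token.expand(n, -1, -1)  # one class token per sample
        x = torch.cat([cls, x], dim=1)            # prepend class token to the sequence
        x = x + self.pos_embedding                # done inside the encoder in torchvision
        return x

m = PatchInput()
out = m(torch.randn(2, 3, 32, 32))
print(out.shape)  # 16 patches + 1 class token -> torch.Size([2, 17, 64])
```

If you skip `_process_input`, the tensor stays in `(n, c, h, w)` layout and the class-token concatenation along the sequence dimension fails, which is consistent with the shape mismatch you are seeing.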