ViT for segmentation


Currently, I am trying to use a ViT (Vision Transformer) as a backbone for image segmentation.
Using a pre-trained ViT, I obtain the following summary for my images:

As one can see, I transformed the output to size [1, 35], which makes this a per-image classification procedure. I would therefore like to transform the final output of the encoder (either of size [1, 35] or the standard [1, 1000]) into the shape [1, num_classes, 384, 384], which is compatible with an image segmentation task.

Reading several papers online, I found some interesting observations, but they do not clearly indicate how to address such problems in code. Could somebody point me in the right direction for scaling the output back to the original image resolution so I can perform segmentation?


Often an encoder-decoder architecture is used, where the decoder takes the (low-dimensional) output of the encoder and upsamples it into the segmentation mask, e.g. via transposed convolutions. Note that a pooled vector of size [1, 35] or [1, 1000] has already discarded all spatial information, so for segmentation you would typically decode from the ViT's per-patch tokens rather than the classification head's output.
I would guess the linked paper applies a similar technique.
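As a rough sketch of this idea (not taken from the linked paper): assuming a ViT-S/16-style encoder on 384x384 images, which yields 24x24 = 576 patch tokens of dimension 384, a hypothetical decoder could reshape the tokens into a 2D feature map and upsample it with stride-2 transposed convolutions. The layer widths and class count (35) here are illustrative choices, not anything prescribed:

```python
import torch
import torch.nn as nn

class SegmentationDecoder(nn.Module):
    """Hypothetical decoder: upsamples ViT patch tokens to a full-resolution mask.

    Assumes 576 patch tokens of dim 384 (24x24 grid from a 384x384 input with
    16x16 patches), with the [CLS] token already removed.
    """

    def __init__(self, embed_dim=384, num_classes=35):
        super().__init__()
        # Four stride-2 transposed convolutions: 24 -> 48 -> 96 -> 192 -> 384
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, 256, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, num_classes, kernel_size=2, stride=2),
        )

    def forward(self, tokens):
        # tokens: [B, N, D] patch embeddings from the ViT encoder
        b, n, d = tokens.shape
        h = w = int(n ** 0.5)                           # 576 tokens -> 24x24 grid
        x = tokens.transpose(1, 2).reshape(b, d, h, w)  # [B, D, 24, 24]
        return self.decoder(x)                          # [B, num_classes, 384, 384]

tokens = torch.randn(1, 576, 384)          # stand-in for real ViT patch features
mask_logits = SegmentationDecoder()(tokens)
print(mask_logits.shape)                   # torch.Size([1, 35, 384, 384])
```

The key point is that the decoder starts from the token grid (which still carries spatial layout), not from the pooled [1, 35] classification vector; a per-pixel loss such as cross-entropy over the 35 classes can then be applied directly to `mask_logits`.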