How to create a very basic example of Transformers' per-pixel classification?

Despite the huge amount of Transformer material on the internet, I can't find a plain PyTorch implementation of a single Transformer layer that takes an image as input and outputs an image. I'm struggling to figure out how this could be done simply.

Performance aside, I'd just like to see how such a simple operation could be expressed with Transformers in PyTorch code. For example, the simple CNN below receives a tensor of shape (B, 3, 224, 224) and outputs a tensor of shape (B, 13, 224, 224). How could I do the same kind of operation with Transformers?

import torch
from torch import nn

class PlainCNN(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        self.cnn1 = nn.Conv2d(3, 100, 3, padding='same')
        self.cnn2 = nn.Conv2d(100, 100, 3, padding='same')
        self.cnn3 = nn.Conv2d(100, 13, 3, padding='same')
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.cnn1(x)
        x = self.relu(x)
        x = self.cnn2(x)
        x = self.relu(x)
        x = self.cnn3(x)
        return x

cnn = PlainCNN()
out = cnn(torch.rand(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 13, 224, 224])
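To make the target concrete, here is a rough sketch of what I imagine a per-pixel Transformer equivalent might look like, treating every pixel as a token. The class name `PlainTransformer` and the hyperparameters (`d_model=64`, `nhead=4`) are just my guesses, I've left out positional encodings, and I realize one token per pixel is far too expensive at 224×224 (presumably why ViT works on patches), but is this the right general shape?

```python
import torch
from torch import nn

class PlainTransformer(nn.Module):
    """Hypothetical sketch: one encoder layer applied over a sequence of pixels."""

    def __init__(self, in_channels=3, num_classes=13, d_model=64, nhead=4):
        super().__init__()
        # lift each pixel's channels to d_model features
        self.embed = nn.Linear(in_channels, d_model)
        # a single self-attention layer over the pixel sequence
        self.encoder = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True
        )
        # per-pixel classification head
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x):
        B, C, H, W = x.shape
        tokens = x.flatten(2).transpose(1, 2)      # (B, H*W, C): one token per pixel
        tokens = self.embed(tokens)                # (B, H*W, d_model)
        tokens = self.encoder(tokens)              # self-attention across pixels
        tokens = self.head(tokens)                 # (B, H*W, num_classes)
        return tokens.transpose(1, 2).reshape(B, -1, H, W)
```

On a small input this at least produces the shape I'm after, e.g. `PlainTransformer()(torch.rand(1, 3, 16, 16))` gives a (1, 13, 16, 16) tensor, though at 224×224 the 50176-token attention matrix would be impractical.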