Visualize vision transformer forward pass

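How can I visualize the forward pass of a vision transformer (specifically, a Pyramid Vision Transformer)?
I know how to visualize CNNs, but I have no idea how to do it for a transformer. I want to visualize `x` at the first and last lines of `forward` below, i.e. the input image batch and the output patch embeddings.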

class PatchEmbed(nn.Module):
    """Image to patch embedding.

    Splits an image into non-overlapping ``patch_size`` patches with a strided
    convolution (kernel_size == stride), projects each patch to ``embed_dim``
    channels, and returns the patches as a layer-normalized token sequence.
    """

    def __init__(self, img_size=112, patch_size=8, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        # Image sizes that are not a multiple of patch_size are allowed: the
        # strided convolution drops trailing rows/columns that do not fill a
        # whole patch, so the patch grid is floor-divided.
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        # kernel_size == stride, so each output position covers exactly one patch.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Embed a batch of images as patch tokens.

        Args:
            x: image batch of shape ``(B, in_chans, H, W)``.

        Returns:
            A tuple ``(tokens, (H_p, W_p))``. ``tokens`` has shape
            ``(B, H_p * W_p, embed_dim)`` with patches in row-major order, and
            ``(H_p, W_p)`` is the size of the patch grid.

        Raises:
            ValueError: if ``x`` is not a 4-D tensor.
        """
        if x.ndim != 4:
            raise ValueError(f"expected a 4-D (B, C, H, W) input, got shape {tuple(x.shape)}")

        x = self.proj(x)  # (B, embed_dim, H_p, W_p)
        # Take the grid size from the conv output so it always matches the
        # number of tokens actually produced.
        H, W = x.shape[-2:]
        x = x.flatten(2).transpose(1, 2)  # (B, H_p * W_p, embed_dim)
        x = self.norm(x)
        return x, (H, W)

    @staticmethod
    def tokens_to_feature_map(tokens, hw):
        """Reshape ``(B, N, C)`` patch tokens back into a ``(B, C, H_p, W_p)`` map.

        This is the inverse of the ``flatten(2).transpose(1, 2)`` step in
        ``forward``. It lets the embedding be shown channel by channel as
        images, the same way CNN feature maps are visualized.

        Args:
            tokens: token tensor of shape ``(B, N, C)`` with ``N == H_p * W_p``.
            hw: the ``(H_p, W_p)`` grid size returned by ``forward``.

        Raises:
            ValueError: if ``N`` does not equal ``H_p * W_p``.
        """
        H, W = hw
        B, N, C = tokens.shape
        if N != H * W:
            raise ValueError(f"expected {H * W} tokens for a {H}x{W} grid, got {N}")
        # transpose() makes the tensor non-contiguous; reshape() copies if needed.
        return tokens.transpose(1, 2).reshape(B, C, H, W)