I need help understanding the Transformer block in CoAtNet

I am studying CoAtNets, which are a fusion of ConvNets and self-attention. I found this PyTorch code in a repository and I am having trouble following it, so I would like some help understanding it. Below I am including the part of the code that I would like help with.
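For reference, as far as I can tell the snippet relies on a few imports and on two helper modules (PreNorm and FeedForward) that are defined elsewhere in the repo. This is my best reconstruction of them, so the exact definitions may differ slightly:

import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange


# my reconstruction of the helpers the snippet references
class PreNorm(nn.Module):
    def __init__(self, dim, fn, norm):
        super().__init__()
        self.norm = norm(dim)   # e.g. nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

And this is the part I am actually asking about: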

class Attention(nn.Module):
    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == inp)
        self.ih, self.iw = image_size

        self.heads = heads
        self.scale = dim_head ** -0.5

        # parameter table of relative position bias:
        # one learnable bias per head for every possible relative offset
        self.relative_bias_table = nn.Parameter(
            torch.zeros((2 * self.ih - 1) * (2 * self.iw - 1), heads))

        # precompute, for every query/key position pair, which row of the
        # bias table holds the bias for their relative offset
        coords = torch.meshgrid((torch.arange(self.ih), torch.arange(self.iw)))
        coords = torch.flatten(torch.stack(coords), 1)
        relative_coords = coords[:, :, None] - coords[:, None, :]

        relative_coords[0] += self.ih - 1       # shift row offsets to start at 0
        relative_coords[1] += self.iw - 1       # shift col offsets to start at 0
        relative_coords[0] *= 2 * self.iw - 1   # row offset acts as a stride over columns
        relative_coords = rearrange(relative_coords, 'c h w -> h w c')
        relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)
        self.register_buffer("relative_index", relative_index)

        self.attend = nn.Softmax(dim=-1)
        # a single projection produces q, k and v in one go
        self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False)

        # project the concatenated heads back to the output dimension,
        # unless a single head already matches the input dimension
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, oup),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # x: (batch, n, inp) where n = ih * iw flattened positions
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(
            t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        # scaled dot-product attention scores: (batch, heads, n, n)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # Use "gather" for more efficiency on GPUs
        relative_bias = self.relative_bias_table.gather(
            0, self.relative_index.repeat(1, self.heads))
        relative_bias = rearrange(
            relative_bias, '(h w) c -> 1 c h w', h=self.ih*self.iw, w=self.ih*self.iw)
        dots = dots + relative_bias

        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')  # merge the heads back together
        out = self.to_out(out)
        return out
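
The part I find hardest to follow is the relative position bias. To check my understanding, I traced the index computation on a tiny 2x2 grid (my own sketch, separate from the repo code):

import torch

ih = iw = 2
coords = torch.meshgrid((torch.arange(ih), torch.arange(iw)))
coords = torch.flatten(torch.stack(coords), 1)              # (2, 4): row and column of each of the 4 pixels
relative_coords = coords[:, :, None] - coords[:, None, :]   # (2, 4, 4): pairwise row/column offsets
relative_coords[0] += ih - 1                                # shift row offsets into 0 .. 2*ih-2
relative_coords[1] += iw - 1                                # shift col offsets into 0 .. 2*iw-2
relative_coords[0] *= 2 * iw - 1                            # row offset becomes a stride over columns
index = relative_coords.sum(0)                              # same result as the rearrange + sum(-1) in the repo
print(index)
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])

So, if I read it right, every query/key pair with the same relative offset points at the same row of the (2*ih - 1) * (2*iw - 1) = 9 row bias table, and the gather in forward() just looks those rows up per head. Is that the correct way to read it?

And here is the Transformer block that wraps this Attention module:
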
class Transformer(nn.Module):
    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, downsample=False, dropout=0.):
        super().__init__()
        hidden_dim = int(inp * 4)
        self.ih, self.iw = image_size
        self.downsample = downsample

        if self.downsample:
            # two pools: one for the shortcut/projection path, one for the attention path
            self.pool1 = nn.MaxPool2d(3, 2, 1)
            self.pool2 = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        self.attn = Attention(inp, oup, image_size, heads, dim_head, dropout)
        self.ff = FeedForward(oup, hidden_dim, dropout)

        # wrap the attention: flatten the feature map to a token sequence,
        # apply pre-norm attention, then reshape back to a feature map
        self.attn = nn.Sequential(
            Rearrange('b c ih iw -> b (ih iw) c'),
            PreNorm(inp, self.attn, nn.LayerNorm),
            Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)
        )

        # same flatten / pre-norm / unflatten wrapping around the feed-forward
        self.ff = nn.Sequential(
            Rearrange('b c ih iw -> b (ih iw) c'),
            PreNorm(oup, self.ff, nn.LayerNorm),
            Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)
        )

    def forward(self, x):
        if self.downsample:
            # downsampling block: pooled projection shortcut + attention on a pooled input
            x = self.proj(self.pool1(x)) + self.attn(self.pool2(x))
        else:
            x = x + self.attn(x)   # residual connection around attention
        x = x + self.ff(x)         # residual connection around the feed-forward
        return x
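
For context, this is how I have been instantiating the block to look at the shapes (my own quick test, relying on my reconstruction of the helpers above):

block = Transformer(inp=64, oup=64, image_size=(16, 16), heads=8, dim_head=32)
x = torch.randn(1, 64, 16, 16)   # (batch, channels, height, width)
out = block(x)
print(out.shape)                 # torch.Size([1, 64, 16, 16]), same spatial size since downsample=False

If I read the downsample branch correctly, image_size seems to refer to the size after pooling, so with downsample=True the input would have to be twice that spatial size, but I am not sure about that. Any explanation of how the pieces fit together would be really appreciated.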