How to convert 2D self-attention to 3D self-attention

Hi All,

I am working on 3D image classification tasks and need 3D self-attention with relative position representation in my network.
I found an attention block with relative position representation for 2D images, but I could not convert it to a 3D version myself, and I did not find an existing 3D implementation online. Can anyone tell me what needs to change and which parts can stay the same? Thanks in advance.

Here is the 2D attention block I want to convert:

```python
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange


class Attention(nn.Module):
    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == inp)

        self.ih, self.iw = image_size

        self.heads = heads
        self.scale = dim_head ** -0.5

        # parameter table of relative position bias
        self.relative_bias_table = nn.Parameter(
            torch.zeros((2 * self.ih - 1) * (2 * self.iw - 1), heads))

        coords = torch.meshgrid((torch.arange(self.ih), torch.arange(self.iw)))
        coords = torch.flatten(torch.stack(coords), 1)
        relative_coords = coords[:, :, None] - coords[:, None, :]

        relative_coords[0] += self.ih - 1
        relative_coords[1] += self.iw - 1
        relative_coords[0] *= 2 * self.iw - 1
        relative_coords = rearrange(relative_coords, 'c h w -> h w c')
        relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)
        self.register_buffer("relative_index", relative_index)

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, oup),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(
            t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # Use "gather" for more efficiency on GPUs
        relative_bias = self.relative_bias_table.gather(
            0, self.relative_index.repeat(1, self.heads))
        relative_bias = rearrange(
            relative_bias, '(h w) c -> 1 c h w', h=self.ih * self.iw, w=self.ih * self.iw)
        dots = dots + relative_bias

        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out
```


The 2D input for this attention block, and the tensor shapes inside it, look like this:

```python
input_tensor = torch.randn(1, 3, 32, 32)

# Reshape the input tensor to match the shape required by the Attention module
reshaped_input = Rearrange('b c ih iw -> b (ih iw) c')(input_tensor)

# Instantiate the Attention module
attention_module = Attention(inp=3, oup=32, image_size=(32, 32))

# Forward pass through the Attention module
output = attention_module(reshaped_input)

# Check the shape of the output tensor
print("Output shape:", output.shape)
# input shape torch.Size([1, 1024, 3])
# q shape torch.Size([1, 8, 1024, 32])
# k shape torch.Size([1, 8, 1024, 32])
# v shape torch.Size([1, 8, 1024, 32])
# dots shape torch.Size([1, 8, 1024, 1024])
# relative bias torch.Size([1048576, 8])
# relative bias after rearranged torch.Size([1, 8, 1024, 1024])
# dots_ dots + relative bias torch.Size([1, 8, 1024, 1024])
# attn  torch.Size([1, 8, 1024, 1024])
# out: attn@v  torch.Size([1, 8, 1024, 32])
# out: rearranged  torch.Size([1, 1024, 256])
# out: self.to_out  torch.Size([1, 1024, 32])
# Output shape: torch.Size([1, 1024, 32])
```
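
To see what the relative-position machinery is doing in 2D, here is a small sanity check I ran on the module above (my own sketch, not part of the original block): every value in the `relative_index` buffer should be a valid row of the `(2*ih - 1) * (2*iw - 1)`-row bias table.

```python
# My own sanity check (assumes the Attention instance created above):
# relative_index holds, for every token pair, the row of relative_bias_table
# that stores the bias for that pair's (dy, dx) offset.
idx = attention_module.relative_index                      # (ih*iw*ih*iw, 1)
table_rows = attention_module.relative_bias_table.shape[0]
print(idx.min().item(), idx.max().item(), table_rows)
# expected for a 32x32 grid: 0 3968 3969
assert idx.min().item() == 0 and idx.max().item() == table_rows - 1
```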

Here is my 3D self-attention with relative position representation.
The relative position bias part is adapted from the `WindowAttention3D` class of the Video Swin Transformer architecture. I am not quite sure how the relative-position index works in 3D, but according to the printed shapes, the inputs, outputs, and dimensions match what I expect (a small worked example of the index math follows after the class below).

The `WindowAttention3D` implementation is in the Video Swin Transformer GitHub repository.

```python
class Attention333(nn.Module):
    def __init__(self, inp, oup, image_size, heads=4, dim_head=32, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == inp)

        self.id, self.ih, self.iw = image_size

        self.heads = heads
        self.scale = dim_head ** -0.5

        # parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * self.id - 1) * (2 * self.ih - 1) * (2 * self.iw - 1), heads))  # (2*Wd-1) * (2*Wh-1) * (2*Ww-1), nH
        print("relative bias table", self.relative_position_bias_table.shape)
        # get pair-wise relative position index for each token inside the window
        coords_d = torch.arange(self.id)
        coords_h = torch.arange(self.ih)
        coords_w = torch.arange(self.iw)
        coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w))  # 3, Wd, Wh, Ww

        coords_flatten = torch.flatten(coords, 1)  # 3, Wd*Wh*Ww
        print("coords shape", coords.shape)

        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 3, Wd*Wh*Ww, Wd*Wh*Ww
        print("relative_coords", relative_coords.shape)
        


        relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3
        print("relative coord shape after permute", relative_coords.shape)
        relative_coords[:, :, 0] += self.id - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.ih - 1
        relative_coords[:, :, 2] += self.iw - 1

        relative_coords[:, :, 0] *= (2 * self.ih - 1) * (2 * self.iw - 1)
        relative_coords[:, :, 1] *= (2 * self.iw - 1)
        relative_position_index = relative_coords.sum(-1)  # Wd*Wh*Ww, Wd*Wh*Ww
        print("relative_potition index", relative_position_index.shape)
        self.register_buffer("relative_position_index", relative_position_index)


        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, oup),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        """ 
            x: input features with shape of (b (ih iw) c)
        """
        print("input size shape", x.shape)
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(
            t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        
        print("q shape", q.shape) 
       
        print("k shape", k.shape) 
        print("v shape", v.shape) #

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # b, heads, n, n
        print("dots shape", dots.shape)

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
        -1, self.id * self.ih * self.iw, self.heads)  # Wd*Wh*Ww,Wd*Wh*Ww,nH
        print("relative bias", relative_position_bias.shape)

        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wd*Wh*Ww, Wd*Wh*Ww

        print("relative bias after permute ", relative_position_bias.shape) # 1, 8, 1024, 1024

        dots = dots + relative_position_bias
        print("dots_ dots + relative bias", dots.shape)

        attn = self.attend(dots)
        print("attn ", attn.shape)

        out = torch.matmul(attn, v)
        print("out: attn@v ", out.shape)

        out = rearrange(out, 'b h n d -> b n (h d)')
        print("out: rearraned ", out.shape)

        out = self.to_out(out)
        print("out: slef.to_out ", out.shape)

        return out
```
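
To convince myself how the index math works in 3D, the snippet below is a tiny standalone sketch (my own, reusing the same formulas as the `__init__` above, with an assumed 2x2x2 window) that builds and prints the relative-position index matrix:

```python
import torch

# Tiny sketch of the 3D relative-position index, same math as
# Attention333.__init__, just on a 2x2x2 window so the result is easy to inspect.
d, h, w = 2, 2, 2
coords = torch.stack(torch.meshgrid(
    torch.arange(d), torch.arange(h), torch.arange(w)))        # 3, d, h, w
coords_flatten = torch.flatten(coords, 1)                       # 3, d*h*w
rel = coords_flatten[:, :, None] - coords_flatten[:, None, :]   # 3, N, N
rel = rel.permute(1, 2, 0).contiguous()                         # N, N, 3
rel[:, :, 0] += d - 1   # shift each offset so it starts at 0
rel[:, :, 1] += h - 1
rel[:, :, 2] += w - 1
rel[:, :, 0] *= (2 * h - 1) * (2 * w - 1)   # mixed-radix flattening of (dz, dy, dx)
rel[:, :, 1] *= (2 * w - 1)
index = rel.sum(-1)                                             # N, N

print(index)
# Each distinct 3D offset between two tokens maps to one row of the bias table,
# so the values lie in [0, (2d-1)*(2h-1)*(2w-1) - 1] = [0, 26].
```

Row `i` of `index` tells you, for every other token `j`, which row of `relative_position_bias_table` to read for the pair `(i, j)`; that lookup is exactly what the advanced indexing in `forward` does before the result is permuted to `(nH, N, N)` and added to `dots`.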


You can test the class with this input; make sure you have installed `einops`:


```python
input_tensor = torch.randn(1, 3, 5, 10, 10)

# Reshape the input tensor to match the shape required by the Attention module
reshaped_input = Rearrange('b c id ih iw -> b (id ih iw) c')(input_tensor)

# Instantiate the Attention module
attention_module = Attention333(inp=3, oup=32, image_size=(5, 10, 10))

# Forward pass through the Attention module
output = attention_module(reshaped_input)

# Check the shape of the output tensor
print("Output shape:", output.shape)````



Printed outputs:


```
relative bias table torch.Size([3249, 4])
coords shape torch.Size([3, 5, 10, 10])
relative_coords torch.Size([3, 500, 500])
relative coord shape after permute torch.Size([500, 500, 3])
relative_position index torch.Size([500, 500])
input size shape torch.Size([1, 500, 3])
q shape torch.Size([1, 4, 500, 32])
k shape torch.Size([1, 4, 500, 32])
v shape torch.Size([1, 4, 500, 32])
dots shape torch.Size([1, 4, 500, 500])
relative bias torch.Size([500, 500, 4])
relative bias after permute  torch.Size([4, 500, 500])
dots_ dots + relative bias torch.Size([1, 4, 500, 500])
attn  torch.Size([1, 4, 500, 500])
out: attn@v  torch.Size([1, 4, 500, 32])
out: rearranged  torch.Size([1, 500, 128])
out: self.to_out  torch.Size([1, 500, 32])
Output shape: torch.Size([1, 500, 32])
```
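
As a final check (again my own sketch, assuming `attention_module` is the `Attention333` instance created above), the index buffer should cover exactly the rows of the 3D bias table, which for `image_size=(5, 10, 10)` has `(2*5-1) * (2*10-1) * (2*10-1) = 3249` rows:

```python
# Sanity check on the instantiated 3D module: relative_position_index must
# only address valid rows of relative_position_bias_table.
idx = attention_module.relative_position_index                      # (500, 500)
table_rows = attention_module.relative_position_bias_table.shape[0]
print(idx.min().item(), idx.max().item(), table_rows)
# expected: 0 3248 3249
assert idx.min().item() == 0 and idx.max().item() == table_rows - 1
```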