Hi All,
I am working on 3D image classification tasks and need 3D self-attention with relative position representation in my network.
I found an attention block with relative position representation for 2D images, but I could not convert it to a 3D version, and I have not been able to find an existing 3D version online. Can anyone tell me what I should change and which parts do not need to change? Thank you in advance.
Here is the attention block I want to convert:
```
import torch
import torch.nn as nn
from einops import rearrange


class Attention(nn.Module):
    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == inp)

        self.ih, self.iw = image_size
        self.heads = heads
        self.scale = dim_head ** -0.5

        # parameter table of relative position bias:
        # one learnable bias per head for each possible (dy, dx) offset
        self.relative_bias_table = nn.Parameter(
            torch.zeros((2 * self.ih - 1) * (2 * self.iw - 1), heads))

        # pairwise relative coordinates between all ih*iw positions
        coords = torch.meshgrid((torch.arange(self.ih), torch.arange(self.iw)))
        coords = torch.flatten(torch.stack(coords), 1)
        relative_coords = coords[:, :, None] - coords[:, None, :]

        # shift offsets to be non-negative, then flatten (dy, dx) into one table index
        relative_coords[0] += self.ih - 1
        relative_coords[1] += self.iw - 1
        relative_coords[0] *= 2 * self.iw - 1
        relative_coords = rearrange(relative_coords, 'c h w -> h w c')
        relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)
        self.register_buffer("relative_index", relative_index)

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, oup),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(
            t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # Use "gather" for more efficiency on GPUs
        relative_bias = self.relative_bias_table.gather(
            0, self.relative_index.repeat(1, self.heads))
        relative_bias = rearrange(
            relative_bias, '(h w) c -> 1 c h w', h=self.ih * self.iw, w=self.ih * self.iw)
        dots = dots + relative_bias

        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out
```
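To show what I understand so far about the part I probably need to change: for an ih x iw grid the bias table has (2*ih - 1) * (2*iw - 1) rows, and every pair of positions is mapped to the row (dy + ih - 1) * (2*iw - 1) + (dx + iw - 1). Here is a small standalone check I wrote myself (not part of the original block), so please correct me if my reading is wrong:
```
import torch

# My own sanity check of the 2D relative-index math (not from the original code)
ih, iw = 4, 4
coords = torch.stack(torch.meshgrid(torch.arange(ih), torch.arange(iw), indexing='ij'))
coords = torch.flatten(coords, 1)              # (2, ih*iw)
rel = coords[:, :, None] - coords[:, None, :]  # (2, N, N) pairwise (dy, dx) offsets
rel[0] += ih - 1                               # dy now in [0, 2*ih - 2]
rel[1] += iw - 1                               # dx now in [0, 2*iw - 2]
rel[0] *= 2 * iw - 1                           # row-major flatten of (dy, dx)
index = rel.sum(0)                             # (N, N) indices into the bias table
assert index.min() == 0
assert index.max() == (2 * ih - 1) * (2 * iw - 1) - 1
```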
The 2D input for this attention block, and the tensor shapes inside the block, look like this:
```
import torch
from einops.layers.torch import Rearrange

input_tensor = torch.randn(1, 3, 32, 32)
# Reshape the input tensor to match the shape required by the Attention module
reshaped_input = Rearrange('b c ih iw -> b (ih iw) c')(input_tensor)
# Instantiate the Attention module
attention_module = Attention(inp=3, oup=32, image_size=(32, 32))
# Forward pass through the Attention module
output = attention_module(reshaped_input)
# Check the shape of the output tensor
print("Output shape:", output.shape)
# input shape torch.Size([1, 1024, 3])
# q shape torch.Size([1, 8, 1024, 32])
# k shape torch.Size([1, 8, 1024, 32])
# v shape torch.Size([1, 8, 1024, 32])
# dots shape torch.Size([1, 8, 1024, 1024])
# relative bias torch.Size([1048576, 8])
# relative bias after rearrange torch.Size([1, 8, 1024, 1024])
# dots + relative bias torch.Size([1, 8, 1024, 1024])
# attn torch.Size([1, 8, 1024, 1024])
# out: attn @ v torch.Size([1, 8, 1024, 32])
# out: rearranged torch.Size([1, 1024, 256])
# out: self.to_out torch.Size([1, 1024, 32])
# Output shape: torch.Size([1, 1024, 32])
```
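For reference, the 3D input I want to feed would be reshaped along the same lines (the depth/height/width values below are just hypothetical sizes, not my real data); I assume all three spatial axes get flattened into the token dimension the same way:
```
import torch
from einops.layers.torch import Rearrange

# Hypothetical 3D volume: (batch, channels, depth, height, width)
input_3d = torch.randn(1, 3, 16, 32, 32)
# Flatten all three spatial axes into the token dimension, analogous to the 2D case
reshaped_3d = Rearrange('b c d ih iw -> b (d ih iw) c')(input_3d)
print(reshaped_3d.shape)  # torch.Size([1, 16384, 3])
```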