Hi,
I am adding nn.MultiheadAttention to the ResNet18 architecture.
My idea is to apply attention after every ResNet block.
For example:
After ResBlock1 the output is a tensor of size [64, 64, 64, 64] (B, C, H, W), and I want to apply nn.MultiheadAttention to this output.
I am confused about the MultiheadAttention part. Previously I was initializing nn.MultiheadAttention with embed_dim equal to H*W of the ResNet layer output, but I was told to set embed_dim equal to the channel dimension C of the ResNet output instead.
Which one is the right approach?
On a side note, I get better accuracy when I use H*W as embed_dim in nn.MultiheadAttention.
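To make the two options concrete, here is a minimal sketch of how I understand the reshapes for each choice, using the [B, 64, 64, 64] output of layer1 (using `batch_first=True` here just to make the shapes explicit; my actual code uses the default layout):

```python
import torch
import torch.nn as nn

x = torch.randn(2, 64, 64, 64)  # [B, C, H, W], e.g. the output of layer1

# Option 1 (what I do now): embed_dim = H*W, sequence length = C
mha_hw = nn.MultiheadAttention(embed_dim=64 * 64, num_heads=8, batch_first=True)
y = x.view(2, 64, 64 * 64)        # [B, C, H*W] -> attends over the 64 channels
out_hw, _ = mha_hw(y, y, y)       # out_hw: [2, 64, 4096]

# Option 2 (what I was told): embed_dim = C, sequence length = H*W
mha_c = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)
z = x.flatten(2).transpose(1, 2)  # [B, H*W, C] -> attends over 4096 spatial positions
out_c, _ = mha_c(z, z, z)         # out_c: [2, 4096, 64]
```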
Here is my code:

```python
self.mha1 = nn.MultiheadAttention(embed_dim=4096, num_heads=8)  # 64*64
self.mha2 = nn.MultiheadAttention(embed_dim=1024, num_heads=8)  # 32*32
self.mha3 = nn.MultiheadAttention(embed_dim=256, num_heads=8)   # 16*16
self.mha4 = nn.MultiheadAttention(embed_dim=64, num_heads=8)    # 8*8

def forward(self, x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    x = self.layer1(x)  # [B, 64, 64, 64] (B, C, H, W)
    y = x.clone()
    y = y.view(-1, 64, 64 * 64)
    y, _ = self.mha1(y, y, y)
    y = y.view(-1, 64, 64, 64)
    x = x + y

    x = self.layer2(x)  # [B, 128, 32, 32] (B, C, H, W)
    y = x.clone()
    y = y.view(-1, 128, 32 * 32)
    y, _ = self.mha2(y, y, y)
    y = y.view(-1, 128, 32, 32)
    x = x + y

    x = self.layer3(x)  # [B, 256, 16, 16] (B, C, H, W)
    y = x.clone()
    y = y.view(-1, 256, 16 * 16)
    y, _ = self.mha3(y, y, y)
    y = y.view(-1, 256, 16, 16)
    x = x + y

    x = self.layer4(x)  # [B, 512, 8, 8] (B, C, H, W)
    y = x.clone()
    y = y.view(-1, 512, 8 * 8)
    y, _ = self.mha4(y, y, y)
    y = y.view(-1, 512, 8, 8)
    x = x + y

    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    x = self.fc(x)
    return x
```