Applying nn.MultiheadAttention on a CNN (ResNet)

Hi,
I am adding nn.MultiheadAttention to the ResNet-18 architecture.

My idea is to apply attention after every ResNet block.

For example:
after ResBlock1 my output is a tensor of size [64, 64, 64, 64] (B, C, H, W),
and I want to apply nn.MultiheadAttention to this output.

I am confused about the multi-head attention part. Previously I was initializing nn.MultiheadAttention with embed_dim equal to H*W of the output of the ResNet layer.

I was then told to initialize embed_dim to be equal to the channel dimension (C) of the ResNet output instead.

Which one is the right approach?

On a side note, I get better accuracy when I use H*W as embed_dim in nn.MultiheadAttention.
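To make the two layouts concrete, here is a minimal sketch of what nn.MultiheadAttention sees in each case with the default batch_first=False (the tensors here are dummies, not from my actual model):

```python
import torch
import torch.nn as nn

B, C, H, W = 2, 64, 64, 64
x = torch.randn(B, C, H, W)

# Option A: embed_dim = H*W. After view(B, C, H*W), the default
# (batch_first=False) layout (L, N, E) makes MHA treat the batch dim
# as the sequence and each channel as a separate batch element.
mha_hw = nn.MultiheadAttention(embed_dim=H * W, num_heads=8)
q = x.view(B, C, H * W)               # interpreted as (L=B, N=C, E=H*W)
a, _ = mha_hw(q, q, q)                # a.shape == (B, C, H*W)

# Option B: embed_dim = C. Flatten the spatial grid into a sequence of
# H*W tokens, each a C-dim vector: (B, C, H, W) -> (H*W, B, C).
tokens = x.flatten(2).permute(2, 0, 1)  # (L=H*W, N=B, E=C)
mha_c = nn.MultiheadAttention(embed_dim=C, num_heads=8)
b, _ = mha_c(tokens, tokens, tokens)    # b.shape == (H*W, B, C)
```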

Here is my code:

self.mha1 = nn.MultiheadAttention(embed_dim=4096, num_heads=8)  # 64 * 64 = H * W after layer1
self.mha2 = nn.MultiheadAttention(embed_dim=1024, num_heads=8)  # 32 * 32
self.mha3 = nn.MultiheadAttention(embed_dim=256, num_heads=8)   # 16 * 16
self.mha4 = nn.MultiheadAttention(embed_dim=64, num_heads=8)    # 8 * 8

def forward(self, x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    x = self.layer1(x)            # [B, 64, 64, 64] (B, C, H, W)
    y = x.clone()
    y = y.view(-1, 64, 64 * 64)   # [B, C, H*W]
    y, _ = self.mha1(y, y, y)
    y = y.view(-1, 64, 64, 64)
    x = x + y                     # residual connection

    x = self.layer2(x)            # [B, 128, 32, 32]
    y = x.clone()
    y = y.view(-1, 128, 32 * 32)
    y, _ = self.mha2(y, y, y)
    y = y.view(-1, 128, 32, 32)
    x = x + y

    x = self.layer3(x)            # [B, 256, 16, 16]
    y = x.clone()
    y = y.view(-1, 256, 16 * 16)
    y, _ = self.mha3(y, y, y)
    y = y.view(-1, 256, 16, 16)
    x = x + y

    x = self.layer4(x)            # [B, 512, 8, 8]
    y = x.clone()
    y = y.view(-1, 512, 8 * 8)
    y, _ = self.mha4(y, y, y)
    y = y.view(-1, 512, 8, 8)
    x = x + y

    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    x = self.fc(x)

    return x
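For comparison, here is a minimal sketch of the channel-as-embedding variant as I understand it (the SpatialSelfAttention helper is something I made up for illustration, not part of my model):

```python
import torch
import torch.nn as nn

class SpatialSelfAttention(nn.Module):
    """Self-attention over spatial positions, with embed_dim = channels.

    Each of the H*W positions becomes a token with a C-dim embedding,
    so attention mixes information across spatial locations.
    """
    def __init__(self, channels, num_heads=8):
        super().__init__()
        # num_heads must divide channels
        self.mha = nn.MultiheadAttention(embed_dim=channels, num_heads=num_heads)

    def forward(self, x):
        b, c, h, w = x.shape
        tokens = x.flatten(2).permute(2, 0, 1)      # (H*W, B, C)
        attn, _ = self.mha(tokens, tokens, tokens)  # attend across positions
        attn = attn.permute(1, 2, 0).reshape(b, c, h, w)
        return x + attn                             # residual, as in my code above
```

Dropping this in after each layer would replace the per-block view + mha calls, e.g. x = self.attn1(x) after self.layer1(x).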