Hierarchical Attention for Few-Shot Image Classification

Is there an existing implementation of hierarchical attention for image classification, or of hierarchical attention for text that could be adapted to images, that uses only attention and no LSTM, GRU, or other RNN?

How should I approach this problem?
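To make the question concrete, here is a rough sketch of the kind of attention-only hierarchy I have in mind (my own illustration, not taken from any existing library): patch embeddings are pooled into region summaries by one attention layer, and the region summaries are pooled into an image summary by a second one, using nothing but nn.MultiheadAttention.

import torch
import torch.nn as nn

class AttentionOnlyHierarchy(nn.Module):
    """Two attention levels: patches -> region summaries -> image summary."""
    def __init__(self, dim=64, num_heads=4, num_classes=10):
        super().__init__()
        self.patch_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.region_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        # learned queries that pool a set of tokens into a single summary token
        self.patch_query = nn.Parameter(torch.randn(1, 1, dim))
        self.region_query = nn.Parameter(torch.randn(1, 1, dim))
        self.classifier = nn.Linear(dim, num_classes)

    def forward(self, x):
        # x: (batch, regions, patches_per_region, dim) -- patch embeddings grouped by region
        b, r, p, d = x.shape
        patches = x.reshape(b * r, p, d)
        q = self.patch_query.expand(b * r, -1, -1)
        region_summary, _ = self.patch_attn(q, patches, patches)   # (b*r, 1, d)
        regions = region_summary.reshape(b, r, d)
        q = self.region_query.expand(b, -1, -1)
        image_summary, _ = self.region_attn(q, regions, regions)   # (b, 1, d)
        return self.classifier(image_summary.squeeze(1))

m = AttentionOnlyHierarchy()
print(m(torch.randn(2, 4, 16, 64)).shape)  # torch.Size([2, 10])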

So far I have done something like this:

import torch
import torch.nn as nn

def conv_block(in_channels, out_channels, k):
    # AttentionStem is the external attention layer mentioned below; it replaces the convolution
    return nn.Sequential(
        AttentionStem(in_channels, out_channels, kernel_size=k, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2)
    )

class Top(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = conv_block(3, 16, 3)
        self.lin = nn.Linear(20, 10)        # two children x 10 outputs each -> 20
        self.childone = Second()
        self.childtwo = Second()

    def forward(self, x):
        x = self.encoder(x)                 # encode once, share the features with both children
        a = self.childone(x)
        b = self.childtwo(x)
        out = torch.cat((a, b), dim=-1)     # (batch, 20)
        return self.lin(out)

class Second(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = conv_block(16, 32, 3)
        self.lin = nn.Linear(20, 10)
        self.childone = Middle()
        self.childtwo = Middle()

    def forward(self, x):
        x = self.encoder(x)
        a = self.childone(x)
        b = self.childtwo(x)
        out = torch.cat((a, b), dim=-1)
        return self.lin(out)

class Middle(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = conv_block(32, 64, 1)
        self.lin = nn.Linear(20, 10)
        self.childone = Bottom()
        self.childtwo = Bottom()

    def forward(self, x):
        x = self.encoder(x)
        a = self.childone(x)
        b = self.childtwo(x)
        out = torch.cat((a, b), dim=-1)
        return self.lin(out)

# (an additional AboveBottom level between Middle and Bottom is commented out for now)

class Bottom(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = conv_block(64, 128, 1)
        self.lin_one = nn.Linear(512, 10)   # 512 must match the flattened size of the encoder output

    def forward(self, x):
        out = self.encoder(x)
        return self.lin_one(out.view(out.size(0), -1))

model = Top()

Is this a correct way to create a hierarchy? The AttentionStem layer is taken from an external implementation.

Typically CNNs have decreasing spatial resolution, so the natural approach would be to use some of the resolution levels as hierarchy levels. The next question is how to formulate the attention: the classic K. Xu et al., "Show, Attend and Tell", uses "positional" attention masks, while Lu et al., "Knowing When to Look", use query-based attention.
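To make that distinction concrete, here is a minimal sketch in my own shorthand (a loose illustration of the two ideas, not the papers' actual formulations): positional attention predicts a weight for every spatial location of a feature map, while query-based attention scores the locations against a query vector.

import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalAttention(nn.Module):
    # "Show, Attend and Tell"-style: a weight map over the spatial grid
    def __init__(self, channels):
        super().__init__()
        self.score = nn.Conv2d(channels, 1, kernel_size=1)

    def forward(self, feats):                       # feats: (B, C, H, W)
        w = self.score(feats)                       # (B, 1, H, W)
        w = F.softmax(w.flatten(2), dim=-1).view_as(w)
        return (feats * w).sum(dim=(2, 3))          # attended vector: (B, C)

class QueryAttention(nn.Module):
    # "Knowing When to Look"-style: spatial locations scored against a query
    def __init__(self, channels, query_dim):
        super().__init__()
        self.proj = nn.Linear(query_dim, channels)

    def forward(self, feats, query):                # feats: (B, C, H, W), query: (B, query_dim)
        q = self.proj(query)                        # (B, C)
        keys = feats.flatten(2)                     # (B, C, H*W)
        scores = torch.einsum('bc,bcn->bn', q, keys) / keys.size(1) ** 0.5
        w = F.softmax(scores, dim=-1)               # (B, H*W)
        return torch.einsum('bcn,bn->bc', keys, w)  # attended vector: (B, C)

In your hierarchy, one option would be to apply such a module at each resolution level and feed the attended vector from the coarser level in as the query for the next finer one.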
It would be interesting to hear about your results and experiences.

Best regards

Thomas