Is there any existing implementation of hierarchical attention for image classification, or of hierarchical attention for text that could be applied to images, that uses only attention, with no LSTM, GRU, or any other RNN?
How should I approach this problem?
So far I have done something like this:
import torch
import torch.nn as nn

def conv_block(in_channels, out_channels, k):
    # attention "stem" followed by the usual BN -> ReLU -> 2x2 max-pool
    return nn.Sequential(
        AttentionStem(in_channels, out_channels, kernel_size=k, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )
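AttentionStem itself comes from external code (see the note at the end). For anyone who wants to run the snippet without it, a minimal hypothetical stand-in with the same constructor signature could look like this; note it is just a plain convolution used as a placeholder, not the real attention layer:

# Hypothetical placeholder so the snippet is runnable without the external
# code: same (in_channels, out_channels, kernel_size, padding) interface,
# but a plain convolution instead of the real attention stem.
class AttentionStem(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, out_channels,
                              kernel_size=kernel_size, padding=padding)

    def forward(self, x):
        return self.proj(x)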
class Top(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = conv_block(3, 16, 3)
        self.lin = nn.Linear(20, 10)
        self.childone = Second()
        self.childtwo = Second()

    def forward(self, x):
        # encode once and feed the same features to both children
        enc = self.encoder(x)
        a = self.childone(enc)
        b = self.childtwo(enc)
        out = torch.cat((a, b), dim=-1)  # two (B, 10) outputs -> (B, 20)
        return self.lin(out)
class Second(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = conv_block(16, 32, 3)
        self.lin = nn.Linear(20, 10)
        self.childone = Middle()
        self.childtwo = Middle()

    def forward(self, x):
        enc = self.encoder(x)
        a = self.childone(enc)
        b = self.childtwo(enc)
        out = torch.cat((a, b), dim=-1)
        return self.lin(out)
class Middle(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = conv_block(32, 64, 1)
        self.lin = nn.Linear(20, 10)
        self.childone = Bottom()
        self.childtwo = Bottom()

    def forward(self, x):
        enc = self.encoder(x)
        a = self.childone(enc)
        b = self.childtwo(enc)
        out = torch.cat((a, b), dim=-1)
        return self.lin(out)
class Bottom(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = conv_block(64, 128, 1)
        # 512 = 128 * 2 * 2, i.e. the flattened feature size for my input
        # resolution; it has to be adjusted for other input sizes
        self.lin_one = nn.Linear(512, 10)

    def forward(self, x):
        out = self.encoder(x)
        return self.lin_one(out.view(out.size(0), -1))  # flatten, then (B, 10)

model = Top()
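A quick way to sanity-check the hierarchy is to push a dummy batch through the model. Here 16x16 is only an assumed input resolution: under the standard output-size rule (H + 2*padding - kernel_size + 1, halved by each pool) it makes the flattened Bottom features come out to 128 * 2 * 2 = 512, matching lin_one; a different resolution would need a different Linear size.

# dummy forward pass; the 16x16 input size is an assumption, not a requirement
x = torch.randn(2, 3, 16, 16)
out = model(x)
print(out.shape)  # expected: torch.Size([2, 10])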
Is this a correct way to create a hierarchy?
AttentionStem is from