This is how the Encoder works:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, num_classes):
        super(Encoder, self).__init__()
        self.initial_block = DownsamplerBlock(3, 16)

        self.layers = nn.ModuleList()
        self.layers.append(DownsamplerBlock(16, 64))

        for x in range(0, 5):  # 5 times
            self.layers.append(non_bottleneck_1d(64, 0.03, 1))

        self.layers.append(DownsamplerBlock(64, 128))

        for x in range(0, 2):  # 2 times, with growing dilation rates
            self.layers.append(non_bottleneck_1d(128, 0.3, 2))
            self.layers.append(non_bottleneck_1d(128, 0.3, 4))
            self.layers.append(non_bottleneck_1d(128, 0.3, 8))
            self.layers.append(non_bottleneck_1d(128, 0.3, 16))

        # Only in encoder mode: 1x1 conv that maps the 128 feature maps to class scores
        self.output_conv = nn.Conv2d(128, num_classes, 1, stride=1, padding=0, bias=True)

    def forward(self, input, predict=False):
        output = self.initial_block(input)
        print(output.size())

        for layer in self.layers:
            output = layer(output)
            print(output.size())

        if predict:
            output = self.output_conv(output)
            print(output.size())

        return output
These are the printed output sizes after each layer:
(1, 16, 256, 341)
(1, 64, 128, 171)
(1, 64, 128, 171)
(1, 64, 128, 171)
(1, 64, 128, 171)
(1, 64, 128, 171)
(1, 64, 128, 171)
(1, 128, 64, 86)
(1, 128, 64, 86)
(1, 128, 64, 86)
(1, 128, 64, 86)
(1, 128, 64, 86)
(1, 128, 64, 86)
(1, 128, 64, 86)
(1, 128, 64, 86)
(1, 128, 64, 86)
(1, 13, 64, 86)
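The final (1, 13, 64, 86) line is the 1x1 output_conv with num_classes=13, and all of the sizes above are consistent with, for example, a 1x3x512x682 input: each DownsamplerBlock halves the spatial resolution, rounding odd sizes up. A minimal sketch to reproduce the printout, assuming the three classes are defined in the same file (the input size is an assumption inferred from the shapes, not taken from the original post):

model = Encoder(num_classes=13)
model.eval()  # use BatchNorm running stats and disable dropout for a deterministic shape check

with torch.no_grad():
    dummy = torch.randn(1, 3, 512, 682)  # assumed input size; any size halving to the shapes above works
    out = model(dummy, predict=True)     # prints the per-layer sizes listed above

print(out.size())  # torch.Size([1, 13, 64, 86])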
And here are the two building blocks, the DownsamplerBlock and the non_bottleneck_1d module:
class DownsamplerBlock(nn.Module):
    def __init__(self, ninput, noutput):
        super(DownsamplerBlock, self).__init__()
        # Strided 3x3 conv produces noutput - ninput channels;
        # the max-pooled input supplies the remaining ninput channels.
        self.conv = nn.Conv2d(ninput, noutput - ninput, (3, 3), stride=2, padding=1, bias=True)
        self.pool = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.bn = nn.BatchNorm2d(noutput, eps=1e-3)

    def forward(self, input):
        # Concatenate conv and pool results along the channel dimension,
        # halving the spatial resolution while widening the channels.
        output = torch.cat([self.conv(input), self.pool(input)], 1)
        output = self.bn(output)
        return F.relu(output)
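The concatenation is why the channel arithmetic works out: for DownsamplerBlock(16, 64), the strided conv contributes 64 - 16 = 48 feature maps and the max-pooled input the other 16, both at half the spatial resolution. A quick shape check, with sizes taken from the printout above:

block = DownsamplerBlock(16, 64)
x = torch.randn(1, 16, 256, 341)
print(block(x).size())  # torch.Size([1, 64, 128, 171]): 48 conv + 16 pooled channels, odd sizes rounded up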
class non_bottleneck_1d(nn.Module):
    def __init__(self, chann, dropprob, dilated):
        super(non_bottleneck_1d, self).__init__()
        # First factorized 3x3 conv: a 3x1 followed by a 1x3
        self.conv3x1_1 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True)
        self.conv1x3_1 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True)
        self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)

        # Second factorized 3x3 conv, dilated to enlarge the receptive field;
        # the padding grows with the dilation so the spatial size is unchanged.
        self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1 * dilated, 0),
                                   bias=True, dilation=(dilated, 1))
        self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1 * dilated),
                                   bias=True, dilation=(1, dilated))
        self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)

        self.dropout = nn.Dropout2d(dropprob)

    def forward(self, input):
        output = self.conv3x1_1(input)
        output = F.relu(output)
        output = self.conv1x3_1(output)
        output = self.bn1(output)
        output = F.relu(output)

        output = self.conv3x1_2(output)
        output = F.relu(output)
        output = self.conv1x3_2(output)
        output = self.bn2(output)

        if self.dropout.p != 0:
            output = self.dropout(output)

        return F.relu(output + input)  # + input = identity (residual connection)
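Because every conv in non_bottleneck_1d keeps the channel count and the padding always matches the dilation, the block is shape-preserving, which is exactly what makes the output + input residual addition legal. A small sketch verifying this for the largest dilation used above:

block = non_bottleneck_1d(128, 0.3, 16)
x = torch.randn(1, 128, 64, 86)
print(block(x).size())  # torch.Size([1, 128, 64, 86]) - same shape as the input

Factorizing each 3x3 conv into a 3x1 plus a 1x3 keeps the 3x3 receptive field but needs 6*chann^2 weights instead of 9*chann^2, which is where this design saves parameters compared to a plain residual block.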