I am trying to implement the Split-Brain Autoencoder in PyTorch. In their implementation, they first pre-train two networks after splitting the input across the channel dimension; then they combine the channels and absorb the BatchNorm layer weights into the convolution layer weights. Finally, they perform the semantic segmentation task. Paper reference (the implementation details are in the Appendix, page 9).
I am not able to understand the significance of absorbing the BatchNorm parameters into the convolution weights, and — if it is significant — how to implement it in PyTorch. My initial network is:
class AlexNet_BN(nn.Module):
    """AlexNet-style convolutional backbone with BatchNorm after every conv.

    Five Conv2d -> BatchNorm2d -> ReLU stages with max-pooling after
    stages 1, 2 and 5, following the Split-Brain Autoencoder setup.

    Args:
        in_channel: number of input image channels (e.g. 1 or 2 for a
            channel-split branch, 3 for the merged network).
        out_channel: kept for interface compatibility.
            NOTE(review): it is never used inside this class — presumably
            consumed by a decoder/head elsewhere; confirm before removing.
        layers: per-stage output channel counts for conv1..conv5.
            A tuple default is used instead of a mutable list default
            (the classic shared-mutable-default pitfall); callers may
            still pass a list, as the value is only indexed.
        out_size: expected spatial output size, stored but not used here.
    """

    def __init__(self, in_channel=3, out_channel=3,
                 layers=(96, 256, 384, 384, 256), out_size=180):
        super(AlexNet_BN, self).__init__()
        self.out_size = out_size
        # Attribute names conv1..conv5 / pool1, pool2, pool5 are kept
        # exactly as before so state_dict keys match existing checkpoints.
        self.conv1 = self._conv_bn_relu(in_channel, layers[0],
                                        kernel_size=11, stride=4, padding=2)  # padding 5
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)  # padding 1
        self.conv2 = self._conv_bn_relu(layers[0], layers[1],
                                        kernel_size=5, stride=1, padding=2)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)  # padding 1
        self.conv3 = self._conv_bn_relu(layers[1], layers[2],
                                        kernel_size=3, stride=1, padding=1)
        self.conv4 = self._conv_bn_relu(layers[2], layers[3],
                                        kernel_size=3, stride=1, padding=1)
        self.conv5 = self._conv_bn_relu(layers[3], layers[4],
                                        kernel_size=3, stride=1, padding=1)
        self.pool5 = nn.MaxPool2d(kernel_size=3, stride=1)  # padding 1 and stride 1

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch, kernel_size, stride, padding):
        # One Conv -> BatchNorm -> ReLU stage; factored out because the
        # five stages differed only in their hyperparameters.
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size,
                      stride=stride, padding=padding),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run the five conv stages (with pooling) and return the feature map."""
        x = self.pool1(self.conv1(x))
        x = self.pool2(self.conv2(x))
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return self.pool5(x)
I would appreciate it if someone could help me in this regard.
Thanks