Hi friends,
I have a question that may sound kind of stupid. I have been reading about a newly designed block called ACBlock (from ACNet), which replaces the single square kernel with three parallel ones: horizontal, vertical, and square. Here is what the block looks like:
import torch.nn as nn
from torch.nn import init


class ACBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, padding_mode='zeros', deploy=False,
                 use_affine=True, reduce_gamma=False, use_last_bn=False, gamma_init=None):
        super(ACBlock, self).__init__()
        self.deploy = deploy
        if deploy:
            # Deploy mode: the three branches have been fused into a single conv
            self.fused_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                        kernel_size=(kernel_size, kernel_size), stride=stride,
                                        padding=padding, dilation=dilation, groups=groups,
                                        bias=True, padding_mode=padding_mode)
        else:
            # Square (k x k) branch with its own BN
            self.square_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=(kernel_size, kernel_size), stride=stride,
                                         padding=padding, dilation=dilation, groups=groups,
                                         bias=False, padding_mode=padding_mode)
            self.square_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)

            # If padding is smaller than kernel_size // 2, the asymmetric branches
            # must crop instead of pad (see the CropLayer sketch after this block)
            # so that all three branches produce the same spatial size
            center_offset_from_origin_border = padding - kernel_size // 2
            ver_pad_or_crop = (padding, center_offset_from_origin_border)
            hor_pad_or_crop = (center_offset_from_origin_border, padding)
            if center_offset_from_origin_border >= 0:
                self.ver_conv_crop_layer = nn.Identity()
                ver_conv_padding = ver_pad_or_crop
                self.hor_conv_crop_layer = nn.Identity()
                hor_conv_padding = hor_pad_or_crop
            else:
                self.ver_conv_crop_layer = CropLayer(crop_set=ver_pad_or_crop)
                ver_conv_padding = (0, 0)
                self.hor_conv_crop_layer = CropLayer(crop_set=hor_pad_or_crop)
                hor_conv_padding = (0, 0)

            # Vertical (k x 1) and horizontal (1 x k) branches, each with its own BN
            self.ver_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                      kernel_size=(kernel_size, 1), stride=stride,
                                      padding=ver_conv_padding, dilation=dilation, groups=groups,
                                      bias=False, padding_mode=padding_mode)
            self.hor_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                      kernel_size=(1, kernel_size), stride=stride,
                                      padding=hor_conv_padding, dilation=dilation, groups=groups,
                                      bias=False, padding_mode=padding_mode)
            self.ver_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)
            self.hor_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)

            if reduce_gamma:
                assert not use_last_bn
                self.init_gamma(1.0 / 3)
            if use_last_bn:
                assert not reduce_gamma
                self.last_bn = nn.BatchNorm2d(num_features=out_channels, affine=True)
            if gamma_init is not None:
                assert not reduce_gamma
                self.init_gamma(gamma_init)

    def init_gamma(self, gamma_value):
        init.constant_(self.square_bn.weight, gamma_value)
        init.constant_(self.ver_bn.weight, gamma_value)
        init.constant_(self.hor_bn.weight, gamma_value)
        print('init gamma of square, ver and hor as ', gamma_value)

    def single_init(self):
        init.constant_(self.square_bn.weight, 1.0)
        init.constant_(self.ver_bn.weight, 0.0)
        init.constant_(self.hor_bn.weight, 0.0)
        print('init gamma of square as 1, ver and hor as 0')

    def forward(self, input):
        if self.deploy:
            return self.fused_conv(input)
        else:
            # Each branch: (optional crop) -> conv -> BN, then sum the three outputs
            square_outputs = self.square_conv(input)
            square_outputs = self.square_bn(square_outputs)
            vertical_outputs = self.ver_conv_crop_layer(input)
            vertical_outputs = self.ver_conv(vertical_outputs)
            vertical_outputs = self.ver_bn(vertical_outputs)
            horizontal_outputs = self.hor_conv_crop_layer(input)
            horizontal_outputs = self.hor_conv(horizontal_outputs)
            horizontal_outputs = self.hor_bn(horizontal_outputs)
            result = square_outputs + vertical_outputs + horizontal_outputs
            if hasattr(self, 'last_bn'):
                return self.last_bn(result)
            return result
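(Note: the block above references a CropLayer that I didn't paste. Here is a minimal sketch consistent with how the block uses it; I'm assuming a negative entry in crop_set means that many rows/columns should be cropped from each side of the feature map:)

class CropLayer(nn.Module):
    # Sketch, assuming crop_set = (row_offset, col_offset) where a negative
    # entry means that many rows/columns are cropped from each side.
    # E.g., (0, -1) crops the first and last columns.
    def __init__(self, crop_set):
        super(CropLayer, self).__init__()
        self.rows_to_crop = -crop_set[0]
        self.cols_to_crop = -crop_set[1]
        assert self.rows_to_crop >= 0
        assert self.cols_to_crop >= 0

    def forward(self, input):
        if self.rows_to_crop > 0:
            input = input[:, :, self.rows_to_crop:-self.rows_to_crop, :]
        if self.cols_to_crop > 0:
            input = input[:, :, :, self.cols_to_crop:-self.cols_to_crop]
        return input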
So my question is: this block already adds a BN layer after each kernel (horizontal, vertical, and square). If I want to re-implement this work, do I still need one more BN layer after the block when I define the ResNet? Like this:
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = ACBlock(planes, planes, kernel_size=3, stride=stride, padding=1)
        # Is this line needed?
        self.bn2 = nn.BatchNorm2d(planes)
        #
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)
        self.relu = nn.ReLU(inplace=True)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        identity = self.shortcut(x)  # project the residual when the shape changes
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        # Is this line needed?
        out = self.bn2(out)
        #
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out += identity
        out = self.relu(out)
        return out
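For what it's worth, here is the quick shape sanity check I'm running (a minimal sketch; it assumes the ACBlock, CropLayer, and Bottleneck definitions above are in scope):

import torch

# One Bottleneck with a 3x3 ACBlock in the middle; 64 -> 256 channels via expansion=4
block = Bottleneck(in_planes=64, planes=64, stride=1)
block.eval()  # use BN running stats for this one-off check
x = torch.randn(2, 64, 32, 32)
with torch.no_grad():
    out = block(x)
print(out.shape)  # torch.Size([2, 256, 32, 32])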
Any suggestions are very much appreciated!