I am currently trying to implement the paper below, and I have defined the modules that follow.
[1911.09737] Filter Response Normalization Layer: Eliminating Batch Dependence in the Training of Deep Neural Networks
Next, I want to convert every BatchNorm2d + ReLU pair in my network to FRN + TLU. (For example, a ReLU that does not come directly after a BatchNorm2d is excluded.)
First, I used the bnrelu_to_frn() below. However, this method only follows the registration order of named_children(), which does not necessarily match the order in which forward() applies the layers. How can I convert specific consecutive layers?
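For concreteness, here is a toy module illustrating the rule I am after (the layers and sizes are made up for illustration):

import torch.nn as nn

# Hypothetical toy network, only to illustrate the intended rewrite.
net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),  # should become FRN(16)
    nn.ReLU(),           # should become TLU(16), since it directly follows a BatchNorm2d
    nn.Conv2d(16, 16, 3, padding=1),
    nn.ReLU(),           # should stay a ReLU: no BatchNorm2d directly before it
)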
import torch
import torch.nn as nn
from torch.nn import BatchNorm2d, LeakyReLU, ReLU

class TLU(nn.Module):
    # Thresholded Linear Unit: elementwise max(y, tau) with a learned per-channel threshold.
    def __init__(self, num_features):
        super(TLU, self).__init__()
        self.num_features = num_features
        self.tau = nn.parameter.Parameter(
            torch.Tensor(1, num_features, 1, 1), requires_grad=True)
        self.reset_parameters()  # without this call, tau holds uninitialized memory

    def reset_parameters(self):
        nn.init.zeros_(self.tau)

    def forward(self, x):
        return torch.max(x, self.tau)  # tau broadcasts over (N, H, W)
class FRN(nn.Module):
    # Filter Response Normalization: y = gamma * x / sqrt(nu^2 + eps) + beta,
    # where nu^2 is the mean of x^2 over each channel's spatial extent.
    def __init__(self, num_features, eps=1e-6, is_eps_leanable=False):
        super(FRN, self).__init__()
        self.num_features = num_features
        self.init_eps = eps
        self.is_eps_leanable = is_eps_leanable
        self.weight = nn.parameter.Parameter(
            torch.Tensor(1, num_features, 1, 1), requires_grad=True)
        self.bias = nn.parameter.Parameter(
            torch.Tensor(1, num_features, 1, 1), requires_grad=True)
        if is_eps_leanable:
            self.eps = nn.parameter.Parameter(torch.Tensor(1), requires_grad=True)
        else:
            self.register_buffer('eps', torch.Tensor([eps]))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.is_eps_leanable:
            nn.init.constant_(self.eps, self.init_eps)

    def forward(self, x):
        # nu^2: mean of squared activations over the spatial dims (H, W).
        nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True)
        # Normalize, then apply the learned per-channel affine transform.
        x = x * torch.rsqrt(nu2 + self.eps.abs())
        return self.weight * x + self.bias
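As a quick smoke test (this only checks that the pair is a drop-in replacement shape-wise, not numerically):

import torch

x = torch.randn(4, 16, 8, 8)  # arbitrary NCHW batch, made up for illustration
frn, tlu = FRN(num_features=16), TLU(num_features=16)
y = tlu(frn(x))
assert y.shape == x.shape  # same output shape as BatchNorm2d(16) + ReLU would give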
def bnrelu_to_frn(module):
    mod = module
    before_name = None
    before_child = None
    is_before_bn = False
    for name, child in module.named_children():
        if is_before_bn and isinstance(child, (ReLU, LeakyReLU)):
            # Swap the preceding BatchNorm2d for FRN ...
            if isinstance(before_child, BatchNorm2d):
                mod.add_module(
                    before_name, FRN(num_features=before_child.num_features))
            else:
                raise NotImplementedError()
            # ... and the activation itself for TLU.
            mod.add_module(name, TLU(num_features=before_child.num_features))
        else:
            # Recurse into child containers.
            mod.add_module(name, bnrelu_to_frn(child))
        before_name = name
        before_child = child
        is_before_bn = isinstance(child, BatchNorm2d)
    return mod
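For reference, the closest workaround I can think of is to restrict the rewrite to nn.Sequential containers, where named_children() order matches forward order by construction. This is only a sketch of that idea (bnrelu_to_frn_sequential is my made-up name), not a general answer, since it still misses modules whose forward() wires children in a different order:

def bnrelu_to_frn_sequential(module):
    # Recurse first so that nested containers are handled too.
    for name, child in module.named_children():
        module.add_module(name, bnrelu_to_frn_sequential(child))
    # Inside nn.Sequential, child order is guaranteed to equal forward order,
    # so adjacent (BatchNorm2d, ReLU/LeakyReLU) pairs can be rewritten safely.
    if isinstance(module, nn.Sequential):
        items = list(module.named_children())
        for (bn_name, bn), (act_name, act) in zip(items, items[1:]):
            if isinstance(bn, BatchNorm2d) and isinstance(act, (ReLU, LeakyReLU)):
                module.add_module(bn_name, FRN(num_features=bn.num_features))
                module.add_module(act_name, TLU(num_features=bn.num_features))
    return module

For anything other than Sequential, the forward order presumably has to come from the graph itself (e.g. by tracing), which is what I am asking about.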