How can I convert specific consecutive layers?

I am currently trying to implement the following paper, and I have defined the modules below.
[1911.09737] Filter Response Normalization Layer: Eliminating Batch Dependence in the Training of Deep Neural Networks

Next, I want to convert every BatchNorm2d + ReLU pair in my network to FRN + TLU (e.g., a ReLU that does not directly follow a BatchNorm2d should be left unchanged).
First, I tried the bnrelu_to_frn() function below. However, this approach does not take the forward (execution) order into account.

How can I convert specific consecutive layers?

class TLU(nn.Module):
    """Thresholded Linear Unit from the FRN paper (arXiv:1911.09737).

    Computes ``max(x, tau)`` elementwise, where ``tau`` is a learnable
    per-channel threshold stored with shape ``(1, num_features, 1, 1)``.

    Args:
        num_features: number of channels (size of dim 1 of the input).
    """

    def __init__(self, num_features):
        super(TLU, self).__init__()
        self.num_features = num_features
        # torch.empty leaves the memory uninitialized; reset_parameters fills it.
        self.tau = nn.Parameter(torch.empty(1, num_features, 1, 1))
        self.reset_parameters()

    def reset_parameters(self):
        """Set ``tau`` to zero, so the layer initially behaves like ReLU."""
        nn.init.zeros_(self.tau)

    def forward(self, x):
        # Assumes x is (N, C, H, W) with C == num_features, so the
        # (1, C, 1, 1) threshold broadcasts per channel -- TODO confirm callers.
        return torch.max(x, self.tau)

    def extra_repr(self):
        return f"num_features={self.num_features}"

class FRN(nn.Module):
    """Filter Response Normalization layer (arXiv:1911.09737).

    Each channel is divided by the root of its mean squared response over
    dims 2 and 3 (plus ``|eps|``), then scaled by ``weight`` and shifted by
    ``bias``.

    Args:
        num_features: number of channels (size of dim 1 of the input).
        eps: initial value of the stabilizing epsilon.
        is_eps_leanable: if True, ``eps`` is a learnable ``nn.Parameter``;
            otherwise it is a fixed buffer. (The spelling is kept for
            backward compatibility with existing callers.)
    """

    def __init__(self, num_features, eps=1e-6, is_eps_leanable=False):
        super(FRN, self).__init__()

        self.num_features = num_features
        self.init_eps = eps
        self.is_eps_leanable = is_eps_leanable

        # Allocated uninitialized; reset_parameters() gives them their values.
        self.weight = nn.Parameter(torch.empty(1, num_features, 1, 1))
        self.bias = nn.Parameter(torch.empty(1, num_features, 1, 1))
        if is_eps_leanable:
            self.eps = nn.Parameter(torch.empty(1))
        else:
            # The else is required: registering a buffer under a name that is
            # already a Parameter raises, and without it self.eps would not
            # exist in the fixed-eps case. A buffer moves with .to() and is
            # saved in state_dict.
            self.register_buffer('eps', torch.tensor([eps]))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the affine transform to identity and eps to its initial value."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.is_eps_leanable:
            nn.init.constant_(self.eps, self.init_eps)

    def forward(self, x):
        # Mean of squared activations over dims 2 and 3 (the spatial dims,
        # assuming NCHW input -- TODO confirm callers).
        nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True)
        # abs() keeps a learnable eps from driving the denominator negative.
        x = x * torch.rsqrt(nu2 + self.eps.abs())
        return self.weight * x + self.bias

    def extra_repr(self):
        return (f"num_features={self.num_features}, eps={self.init_eps}, "
                f"is_eps_leanable={self.is_eps_leanable}")

def _match_device_dtype(new_module, reference):
    """Move *new_module* to the device/dtype of *reference*'s first floating-point tensor."""
    # Look at parameters first, then buffers; integer buffers such as
    # BatchNorm's num_batches_tracked are skipped so dtype stays floating.
    tensors = list(reference.parameters(recurse=False)) + list(reference.buffers(recurse=False))
    for tensor in tensors:
        if tensor.is_floating_point():
            return new_module.to(device=tensor.device, dtype=tensor.dtype)
    return new_module


def bnrelu_to_frn(module, sequential_only=True):
    """Replace each ``BatchNorm2d -> ReLU/LeakyReLU`` pair with ``FRN -> TLU``, in place.

    Pairs are found by adjacency among a container's direct children as
    returned by ``named_children()``. That order is *registration* order
    (the order of assignment in ``__init__``), which matches execution order
    only for ``nn.Sequential``. In an arbitrary module, a registered ReLU may
    be shared or applied somewhere other than right after the preceding
    BatchNorm2d. So by default only children of ``nn.Sequential`` containers
    are paired, while every other module is still searched recursively.

    A ReLU/LeakyReLU that does not directly follow a BatchNorm2d is left
    unchanged.

    Args:
        module: root module to convert; it is modified in place.
        sequential_only: if True (default), pair adjacent children only inside
            ``nn.Sequential`` containers. If False, pair by registration order
            in every container. Only do this when you know registration order
            equals forward order.

    Returns:
        The same *module* object, for convenience.
    """
    can_pair = isinstance(module, nn.Sequential) or not sequential_only
    prev_name = None
    prev_child = None
    # Snapshot the children: entries are reassigned via add_module below.
    for name, child in list(module.named_children()):
        if (can_pair
                and isinstance(prev_child, nn.BatchNorm2d)
                and isinstance(child, (nn.ReLU, nn.LeakyReLU))):
            num_features = prev_child.num_features
            # add_module with an existing name replaces that entry in place,
            # keeping its position in the container.
            module.add_module(
                prev_name,
                _match_device_dtype(FRN(num_features=num_features), prev_child))
            module.add_module(
                name,
                _match_device_dtype(TLU(num_features=num_features), prev_child))
        else:
            bnrelu_to_frn(child, sequential_only=sequential_only)
        # Track the ORIGINAL child, so a replaced ReLU can never start a new pair.
        prev_name, prev_child = name, child
    return module
1 Like

@ptrblck, is there any solution for this?

As explained by the author of this question, the execution order is not preserved by .named_children(): it returns modules in the order in which they were registered (i.e., assigned in __init__), not the order in which forward() calls them. I don't know if there is a clean way to perform this kind of model surgery automatically without, e.g., manually overriding some modules.