How to change all ReLU to TLU in pretrained models?

I was trying to replace all the ReLU activation functions in the skresnet34 model (from the timm library) with TLU, as shown here: PyTorch-FilterResponseNormalizationLayer/frn.py at master · yukkyo/PyTorch-FilterResponseNormalizationLayer · GitHub and in the Filter Response Normalization Layer paper.
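For reference, the model is loaded from timm roughly like this (a minimal sketch, assuming a recent timm version with pretrained weights available):

import timm
import torch.nn as nn

# load the pretrained skresnet34 and list the ReLU activations that need swapping
model = timm.create_model("skresnet34", pretrained=True)

for name, module in model.named_modules():
    if isinstance(module, nn.ReLU):
        print(name, module)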

I was trying this:


import torch
import torch.nn as nn


class TLU(nn.Module):
    def __init__(self, num_features=32):
        """max(y, tau) = max(y - tau, 0) + tau = ReLU(y - tau) + tau"""
        super(TLU, self).__init__()
        self.num_features = num_features
        self.tau = nn.Parameter(
            torch.Tensor(1, num_features, 1, 1), requires_grad=True)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.zeros_(self.tau)

    def extra_repr(self):
        return 'num_features={num_features}'.format(**self.__dict__)

    def forward(self, x):
        return torch.max(x, self.tau)
    

def convert_relu_to_tlu(model):
    # recursively swap every nn.ReLU child for a TLU (with the default num_features)
    for child_name, child in model.named_children():
        if isinstance(child, nn.ReLU):
            setattr(model, child_name, TLU())
        else:
            convert_relu_to_tlu(child)

def getter(model, name):
    # walk a dotted module name (as returned by named_modules()) down to the module
    layer = model
    for attrib in name.split("."):
        layer = getattr(layer, attrib)
    return layer

def setter(model, name, layer):
    # assign a new module at a dotted name by resolving the parent module first
    try:
        attrib, name = name.rsplit(".", 1)
        model = getter(model, attrib)
    except ValueError:
        pass
    setattr(model, name, layer)

# model: the pretrained skresnet34 loaded from timm (see the sketch above)
for name, module in model.named_modules():
    if isinstance(module, nn.ReLU):
        relu = getter(model, name)
        print(relu)
        #tlu = TLU(num_features = relu.num_features)
        tlu = TLU()
        print("changing {} with {}".format(relu, tlu))

        setter(model, name, tlu)
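For what it's worth, a quick sanity check along these lines (just a sketch) after running the loop shows that no nn.ReLU modules are left:

assert not any(isinstance(m, nn.ReLU) for m in model.modules())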

So the replacement itself works, but since nn.ReLU carries no num_features attribute, I can't pass the exact value of num_features to each TLU while replacing. How do I modify this line of code:


tlu = TLU(num_features = relu.num_features)

so that every ReLU is replaced by a TLU with the correct num_features?