Do you mean something like this?
def dfs(
    module: "nn.Module",
    modify_module: "Callable[[nn.Module, bool, bool], object]",
    is_root: bool = True,
) -> "nn.Module":
    """Visit *module* and all of its submodules in post-order, calling a hook on each.

    Every direct child is processed (recursively) before ``modify_module`` is
    called on *module* itself, so the hook always sees a module whose subtree
    has already been visited. There is no global visited set, so a submodule
    registered under several parents is visited once per parent.

    Args:
        module: Root of the (sub)tree to traverse.
        modify_module: Hook called as ``modify_module(m, is_root, is_leaf)``
            for every visited module ``m``. Its return value is ignored.
        is_root: Whether *module* is the top of the traversal. Callers normally
            leave the default; recursive calls pass ``False``.

    Returns:
        The same *module* object that was passed in, after the hook has been
        applied to it.
    """
    # Annotations are quoted so that they are not evaluated at definition time;
    # the hook is a plain callable, not an autograd ``Function`` subclass.

    # Snapshot the direct children before recursing, so that a hook run on a
    # descendant that mutates this module's submodules cannot invalidate or
    # extend the iteration below.
    children = list(module.children())
    for child in children:
        dfs(child, modify_module, is_root=False)

    # A leaf is a module with no direct submodules.
    modify_module(module, is_root, not children)
    return module