Module.children() vs Module.modules()

@ post above me
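You can iterate recursively through the whole network. Think of it as a tree: container modules such as `nn.Sequential` are internal nodes, and layers with no children are the leaves.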

```python
import torch.nn as nn

all_layers = []
def remove_sequential(network):
    for layer in network.children():
        if isinstance(layer, nn.Sequential): # if sequential layer, apply recursively to layers in sequential layer
            remove_sequential(layer)
        if list(layer.children()) == []: # if leaf node, add it to list
            all_layers.append(layer)
    # note: other containers (nn.ModuleList, custom modules) are neither recursed into nor collected
```
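A minimal sketch of the `children()` vs `modules()` distinction, assuming PyTorch is installed: `children()` yields only the immediate sub-modules, while `modules()` walks the whole tree depth-first, starting with the module itself. The helper `collect_leaves` and the `Block` class are hypothetical names used only for illustration. Unlike `remove_sequential`, this version descends into any module that has children, so layers inside custom blocks are not skipped.

```python
import torch.nn as nn


def collect_leaves(module):
    """Hypothetical helper: recursively collect modules that have no children."""
    leaves = []
    for child in module.children():          # immediate sub-modules only
        if list(child.children()):           # internal node: keep descending
            leaves.extend(collect_leaves(child))
        else:                                # leaf node: record it
            leaves.append(child)
    return leaves


class Block(nn.Module):
    """Hypothetical custom container (not nn.Sequential) for this example."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.fc(x))


model = nn.Sequential(
    nn.Linear(4, 8),
    nn.Sequential(nn.ReLU(), Block()),
    nn.Linear(8, 2),
)

# children(): only the 3 direct sub-modules of `model`
print(len(list(model.children())))  # → 3

# modules(): `model` itself plus all 7 descendants
print(len(list(model.modules())))  # → 8

# Leaves found by explicit recursion ...
leaves = collect_leaves(model)
print([type(m).__name__ for m in leaves])  # → ['Linear', 'ReLU', 'Linear', 'ReLU', 'Linear']

# ... match filtering modules() for modules with no children
leaves_via_modules = [m for m in model.modules() if not list(m.children())]
print(len(leaves_via_modules))  # → 5
```

On this model, `remove_sequential` would collect only 3 layers, because it does not descend into `Block`. Note also that `modules()` visits a shared sub-module only once, whereas a hand-written recursion like `collect_leaves` lists it once per parent that holds it.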