In reply to the post above:
You could iterate through the whole network by treating it as a tree: walk it recursively and collect the leaf modules (those with no children of their own).
# Module-level accumulator that remove_sequential appends to by default.
all_layers = []


def remove_sequential(network, layers=None):
    """Flatten ``network`` into its leaf modules, depth-first in definition order.

    Every child that has children of its own (``nn.Sequential``,
    ``nn.ModuleList``, custom ``nn.Module`` blocks, subclasses of any of
    these) is descended into. Every child without children is appended to
    ``layers``. ``network`` itself is never appended, only its descendants.

    Args:
        network: the module whose descendant leaf layers are collected.
        layers: list to append the leaves to. Defaults to the module-level
            ``all_layers`` list, so existing ``remove_sequential(net)`` calls
            keep filling that global exactly as before.

    Returns:
        The list the leaves were appended to (``all_layers`` by default).
    """
    if layers is None:
        layers = all_layers
    for layer in network.children():
        # children() is a generator; peeking one item tests emptiness without
        # materialising the whole child list.
        if next(layer.children(), None) is None:
            layers.append(layer)
        else:
            # Recurse into ANY container, not only nn.Sequential: otherwise the
            # leaves inside e.g. a ModuleList or custom block are silently lost.
            remove_sequential(layer, layers)
    return layers