Alter model layers of VGG and rebuild module structure

Hey everyone,

I want to optimize a vanilla VGG model after training by deleting unnecessary feature maps from its BatchNorm2d (BN) and Conv2d (C2D) layers.

To achieve that, I first flatten all layers of the vanilla VGG into a list, pick out all the BN and C2D layers, modify them, and put the flattened layers back into an nn.Sequential() as my new model.

This fails because the forward pass of the vanilla VGG model looks like this:

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1) # -> flattening tensor does not happen when using nn.Sequential(), shape mismatch
        x = self.classifier(x)
        return x
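(As an aside, the flatten step itself can be expressed as a module: nn.Flatten(1) mirrors torch.flatten(x, 1), so a rebuilt nn.Sequential can avoid the shape mismatch. A minimal sketch with toy shapes standing in for the real VGG blocks:)

```python
import torch
import torch.nn as nn

# toy stand-ins for the VGG blocks (shapes are illustrative, not the real VGG)
features = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
avgpool = nn.AdaptiveAvgPool2d((2, 2))
classifier = nn.Linear(8 * 2 * 2, 10)

# nn.Flatten(1) reproduces torch.flatten(x, 1) as a module,
# so the whole forward pass fits into one nn.Sequential
rebuilt = nn.Sequential(features, avgpool, nn.Flatten(1), classifier)
out = rebuilt(torch.randn(4, 3, 16, 16))  # out.shape == (4, 10)
```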

My question now is:
How can I maintain the model structure while altering the layers? My algorithm currently relies on two lists of all BN and all C2D layers. Is there something like a unique ID that a layer carries inside the model structure, so that I can just replace it with something like:

    model.modules['uniqueID'] = altered_layer

If there are no unique IDs, can you think of a way to build them myself (encoding the module structure and decoding it)?
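While experimenting I noticed that the dotted names yielded by model.named_modules() behave a lot like such unique IDs, and get_submodule() plus setattr() can swap a layer by name. A minimal sketch with a toy model (the class and layer sizes are made up for illustration):

```python
import torch.nn as nn

# toy VGG-like model; names and sizes are illustrative
class TinyVGG(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
        self.classifier = nn.Linear(8, 2)

model = TinyVGG()

# the dotted names from named_modules() act as unique IDs
ids = [name for name, m in model.named_modules()
       if isinstance(m, (nn.Conv2d, nn.BatchNorm2d))]
# ids == ['features.0', 'features.1']

# replace a layer by resolving its parent container, then setattr()
name = ids[1]                                  # 'features.1'
parent_name, _, child = name.rpartition('.')
parent = model.get_submodule(parent_name) if parent_name else model
setattr(parent, child, nn.Identity())
```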

I could just make it work by writing my own model class with this specific forward pass, but I consider replacing the layers in place to be more elegant.

Help is appreciated!


Meanwhile I have tinkered together some code that works:

    import functools
    import torch

    def get_module_layers_and_keychains(module):
        """Returns all leaf layers of a PyTorch module together with a
        keychain as identifier, e.g.:
            (['features', '0', '5'], nn.ReLU())
            (['classifier', '0'], nn.BatchNorm2d())
            (['classifier', '1'], nn.Linear())
        """
        flattened = []
        for key, child in module._modules.items():
            if not child._modules:
                # leaf layer: the keychain is just its own key
                structure = [([key], child)]
            else:
                # container: recurse, then prepend the current key
                structure = get_module_layers_and_keychains(child)
                structure = [([key] + keylist, leaf)
                             for keylist, leaf in structure]
            flattened.extend(structure)
        return flattened
    def replace_module_layer(module, layer_keychain, new_layer):
        """Replaces a layer inside a PyTorch module for a given keychain.
           Use get_module_layers_and_keychains() to retrieve valid keychains.
        """
        # walk down to the parent container, then swap the leaf in place
        get_fn = lambda mod, key: mod._modules[key]
        l_root = functools.reduce(get_fn, layer_keychain[:-1], module)
        l_root._modules[layer_keychain[-1]] = new_layer
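To illustrate the traversal in replace_module_layer on its own: functools.reduce() walks the keychain down to the parent container, and the last key indexes the leaf to swap (toy model below, keychain chosen by hand):

```python
import functools
import torch.nn as nn

# nested toy model: the ReLU lives at keychain ['0', '1']
model = nn.Sequential(nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()))
keychain = ['0', '1']

# walk down to the parent container, then swap the leaf
get_fn = lambda mod, key: mod._modules[key]
parent = functools.reduce(get_fn, keychain[:-1], model)
parent._modules[keychain[-1]] = nn.Identity()
# model[0][1] is now an nn.Identity()
```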

I am still interested in your thoughts and ideas!