How to extract features of an image from a trained model

@alexis-jacq I wouldn’t recommend that. It’s better to keep your models stateless, i.e. not to hold any of the intermediate states as attributes. Otherwise, if you don’t pay close attention to them, you may end up holding references to autograd graphs you no longer need, and they will only take up memory.
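To make the memory concern concrete, here is a minimal sketch with a hypothetical `StatefulNet` that stores an intermediate activation on `self`. As long as that attribute exists, it keeps the activation and, through its `grad_fn`, the autograd graph that produced it reachable, even after the caller has discarded the output. The hypothetical `StatelessNet` returns the activation instead, so it can be freed as soon as the caller drops it.

```python
import torch
import torch.nn as nn


class StatefulNet(nn.Module):
    """Hypothetical model that keeps an intermediate activation as state (not recommended)."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)
        self.features = None

    def forward(self, x):
        h = torch.relu(self.fc1(x))
        self.features = h  # keeps h and, via h.grad_fn, its graph alive on the model
        return self.fc2(h)


class StatelessNet(nn.Module):
    """Same computation, but the activation is returned instead of stored."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        h = torch.relu(self.fc1(x))
        return self.fc2(h), h


x = torch.randn(3, 4)

stateful = StatefulNet()
out = stateful(x)
del out
# The graph behind the stored activation is still reachable from the model itself.
print(stateful.features.grad_fn is not None)  # True

stateless = StatelessNet()
out, feats = stateless(x)
del out, feats  # nothing on the model still refers to the graph
```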

If you really want to do something like that, I’d recommend this:

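import torch.nn as nn

class FeatureExtractor(nn.Module):
    def __init__(self, submodule, extracted_layers):
        super().__init__()  # must run before assigning submodules
        self.submodule = submodule
        self.extracted_layers = extracted_layers  # names of children whose outputs to keep

    def forward(self, x):
        outputs = []
        # Feed x through each direct child of submodule, in registration order
        for name, module in self.submodule._modules.items():
            x = module(x)
            if name in self.extracted_layers:
                outputs += [x]
        # Requested intermediate outputs first, final output last
        return outputs + [x]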

This unfortunately uses a private member _modules, but I don’t expect it to change in the near future, and we’ll probably expose an API for iterating over modules with names soon.
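Current PyTorch releases provide the public `nn.Module.named_children()` method, which yields the same `(name, module)` pairs for the direct children as `_modules.items()`. Below is a minimal sketch of the same extractor written against that method, applied to a small hypothetical `nn.Sequential` (whose children are named `"0"`, `"1"`, ... by default). Since the loop simply calls the children in order, it only matches the model's own forward pass when the submodule is a plain chain of layers; a model whose `forward` does extra work between children (for example flattening before a classifier) would need that step handled separately.

```python
import torch
import torch.nn as nn


class FeatureExtractor(nn.Module):
    """Runs `submodule`'s direct children in order and collects selected outputs."""

    def __init__(self, submodule, extracted_layers):
        super().__init__()
        self.submodule = submodule
        self.extracted_layers = set(extracted_layers)

    def forward(self, x):
        outputs = []
        # named_children() is the public way to iterate over (name, child) pairs
        for name, module in self.submodule.named_children():
            x = module(x)
            if name in self.extracted_layers:
                outputs.append(x)
        # Requested intermediate outputs first, final output last
        return outputs + [x]


# Hypothetical toy network: children are registered as "0", "1", "2", "3"
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
    nn.ReLU(),
)

inp = torch.randn(2, 3, 32, 32)
extractor = FeatureExtractor(model, extracted_layers=["1", "3"])
with torch.no_grad():  # feature extraction only, so no autograd graph is built
    feats = extractor(inp)

for f in feats:
    print(tuple(f.shape))
```

The last element is the network's final output; here it is the very same tensor as the one captured at layer `"3"`, because that is the last child.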
