Could you check your current PyTorch version? If I’m not mistaken, torch.nn.utils.prune was introduced in 1.4.0, so you would need that version or newer.
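A quick way to check both the version and the availability of the module:

```python
import torch

# torch.nn.utils.prune ships with PyTorch >= 1.4.0
print(torch.__version__)

# On older versions this import raises ModuleNotFoundError
import torch.nn.utils.prune
```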
All the built-in methods in the prune module rank by the norm of the parameters (weights and biases). Any suggestions if I want to prune connections based on activations instead (e.g., remove 10% of connections based on the norm of their activations)? Should I try implementing it by subclassing, or do you know whether it has been implemented somewhere else?
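To make the question concrete, something like the following is what I have in mind: collect per-unit activation norms with a forward hook, turn them into a mask, and apply it with prune.custom_from_mask (which is a real entry point for arbitrary masks). The helper name, the row-wise granularity, and the choice to remove the *lowest*-norm units are all my own; flip `largest=` for the opposite ranking.

```python
import torch
import torch.nn as nn
from torch.nn.utils import prune

def activation_based_mask(module, batches, amount=0.1):
    """Hypothetical helper (not part of torch.nn.utils.prune):
    build a weight mask that zeroes whole output units (rows) whose
    accumulated activation L1-norm is lowest over `batches`."""
    norms = torch.zeros(module.out_features)

    def hook(m, inp, out):
        # accumulate the per-unit L1-norm of the activations
        norms.add_(out.detach().abs().sum(dim=0))

    handle = module.register_forward_hook(hook)
    with torch.no_grad():
        for x in batches:
            module(x)
    handle.remove()

    n_prune = int(round(amount * module.out_features))
    mask = torch.ones_like(module.weight)
    if n_prune > 0:
        idx = torch.topk(norms, k=n_prune, largest=False).indices
        mask[idx, :] = 0  # remove every incoming connection of the weakest units
    return mask

layer = nn.Linear(8, 4)
mask = activation_based_mask(layer, [torch.randn(16, 8)], amount=0.25)
prune.custom_from_mask(layer, name="weight", mask=mask)
```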
Also, is it easy to reverse the ranking of torch.nn.utils.prune.l1_unstructured? The docs say it removes “the specified amount of (currently unpruned) units with the lowest L1-norm”. What if I want to remove the ones with the highest L1-norm? Let’s assume I have an application for the opposite ranking; I’m not sure whether it makes sense according to the pruning literature.
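For concreteness, I imagine the flipped ranking would look roughly like this (a rough sketch; the class and function names are made up, but BasePruningMethod, compute_mask, and apply are the real extension points):

```python
import torch
import torch.nn as nn
from torch.nn.utils import prune

class HighestL1Unstructured(prune.BasePruningMethod):
    """Like prune.L1Unstructured, but removes the entries with the
    *highest* L1-norm (magnitude). Illustrative sketch only."""
    PRUNING_TYPE = "unstructured"

    def __init__(self, amount):
        self.amount = amount  # float fraction in [0, 1] or an int count

    def compute_mask(self, t, default_mask):
        n = t.nelement()
        k = int(round(self.amount * n)) if isinstance(self.amount, float) else int(self.amount)
        mask = default_mask.clone()
        if k > 0:
            # largest=True is the only change vs. the built-in lowest-norm ranking
            idx = torch.topk(t.abs().flatten(), k=k, largest=True).indices
            mask.view(-1)[idx] = 0
        return mask

def highest_l1_unstructured(module, name, amount):
    HighestL1Unstructured.apply(module, name, amount)
    return module

lin = nn.Linear(4, 4)
highest_l1_unstructured(lin, name="weight", amount=0.25)  # zero the 4 largest weights
```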
I haven’t implemented custom pruning methods so far.
However, @Michela presents the pruning module, as well as ways to customize it, in this talk, which might be a good starting point.
I can help with the implementation, whether it ends up making it into the official list of PyTorch-supported methods or ends up living in your own separate repository.
Either way, I’d recommend subclassing BasePruningMethod so that you can take advantage of all the pre-packaged functionality provided by the base class and remain compatible with PyTorch’s pruning “language”. The bottom portion of the pruning tutorial shows how to extend the module with your own custom pruning method (through a very simple example, in that case).
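For example, a minimal custom method in the tutorial’s style might look like this (the every-other-entry rule and all names below are just illustrative; you only need to set PRUNING_TYPE and implement compute_mask, and the base class handles the reparametrization):

```python
import torch
import torch.nn as nn
from torch.nn.utils import prune

class EveryOtherPruningMethod(prune.BasePruningMethod):
    """Toy method: prune every other entry of a tensor."""
    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0  # zero out every other entry
        return mask

def every_other_unstructured(module, name):
    """Convenience wrapper, mirroring the built-in functional interface."""
    EveryOtherPruningMethod.apply(module, name)
    return module

fc = nn.Linear(3, 3)
every_other_unstructured(fc, name="bias")
# fc.bias is now fc.bias_orig * fc.bias_mask, with zeros at even indices
```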