Custom pruning method for Global Unstructured Pruning

Hi all,
So I am trying to add a custom global pruning method by extending the PyTorch pruning class. I already have the mask for each layer stored in a dict, where the key is the module name (like conv1) and the value is the mask for that module. But I want to use the torch class so that when I fine-tune the model in the future, those weights remain 0.
Any idea how I can do this? It would be a great help.

To achieve this, you can subclass the base pruning class from the torch.nn.utils.prune module. Here's a high-level approach:

  1. Create a custom pruning class:
  • Subclass BasePruningMethod from torch.nn.utils.prune.
  • Implement your custom pruning logic in the compute_mask method, and set PRUNING_TYPE = 'unstructured' so the class can be used with global_unstructured.
  2. Apply the custom pruning method:
  • Pass the class (not an instance) to prune.global_unstructured, together with the (module, parameter name) pairs you want to prune.

Here’s a basic example to get you started:

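import torch
import torch.nn.utils.prune as prune

class CustomPruning(prune.BasePruningMethod):
    # global_unstructured() only accepts methods of this type
    PRUNING_TYPE = 'unstructured'

    def __init__(self, threshold=0.1):
        self.threshold = threshold

    def compute_mask(self, t, default_mask):
        # Start from the existing mask so previously pruned weights stay pruned
        mask = default_mask.clone()
        # Example: set mask to 0 for weights whose magnitude is below the threshold
        mask[t.abs() < self.threshold] = 0
        return mask

# Example model
model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.ReLU(),
    torch.nn.Linear(5, 1)
)

# Apply custom pruning; parameters must be an iterable of (module, name) tuples
parameters_to_prune = [(model[0], 'weight'), (model[2], 'weight')]
# Extra keyword arguments (here threshold) are passed to the CustomPruning constructor
prune.global_unstructured(parameters_to_prune, pruning_method=CustomPruning, threshold=0.1)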

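This example shows how to create a custom pruning class and apply it globally. You can adjust compute_mask to use your stored masks instead of a threshold, or skip the custom class and attach each stored mask with prune.custom_from_mask. Either way, each pruned module keeps weight_orig (a parameter) and weight_mask (a buffer), and a forward pre-hook recomputes weight = weight_orig * weight_mask, so the pruned weights stay zero during training. When reloading, the pruning has to be reapplied before load_state_dict, because the saved state dict contains weight_orig and weight_mask instead of weight.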
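Since the masks already live in a dict keyed by module name, a minimal sketch of attaching them directly with prune.custom_from_mask might look like this. It assumes the masks are 0/1 tensors with the same shape as each module's weight; the Net model, the apply_stored_masks helper and the random masks are hypothetical stand-ins for your own model and stored masks.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


class Net(nn.Module):
    """Hypothetical model whose module names match the keys of the mask dict."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3)
        self.fc1 = nn.Linear(4 * 26 * 26, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))  # (N, 1, 28, 28) -> (N, 4, 26, 26)
        return self.fc1(x.flatten(1))


def apply_stored_masks(model, masks, param_name="weight"):
    """Hypothetical helper: attach precomputed 0/1 masks (module name -> tensor)."""
    modules = dict(model.named_modules())
    for module_name, mask in masks.items():
        # Registers <param>_orig (parameter), <param>_mask (buffer) and a forward
        # pre-hook that recomputes <param> = <param>_orig * <param>_mask
        prune.custom_from_mask(modules[module_name], name=param_name, mask=mask)
    return model


model = Net()
# Random stand-ins for the masks you already have stored
masks = {
    "conv1": (torch.rand_like(model.conv1.weight) > 0.5).float(),
    "fc1": (torch.rand_like(model.fc1.weight) > 0.5).float(),
}
apply_stored_masks(model, masks)

# Masked entries of the effective weight are exactly zero
assert torch.all(model.conv1.weight[masks["conv1"] == 0] == 0)
print(sorted(k for k in model.state_dict() if k.startswith("conv1.")))
```

The printed keys show that the state dict now stores weight_orig and weight_mask rather than weight, which is why the same masks (or some pruning call) must be applied to a fresh model before loading such a checkpoint.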

Hey, thanks! I did something like this. I have one more question:

I pruned the model, and each pruned module now has weight_mask, weight_orig, and a weight attribute. During fine-tuning I want to keep the pruning, so that only the weights that are not 0 get updated. I found two ways of doing this; please correct me if I am wrong:

  1. Fine-tune the pruned model directly. Here my question is: does weight_orig get updated, or the weight attribute? If weight_orig is updated, then the final weight is computed as weight = weight_orig * weight_mask. When I save this model, the state dict contains weight_orig and weight_mask, so when I want to use this fine-tuned pruned model again, I load it and compute .weight as weight_orig * weight_mask.
  2. After fine-tuning, should I make the pruning permanent by removing the reparametrization (prune.remove)? Then, when I load the saved model again, .weight already contains only the pruned weights.
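Both options can be checked with a small sketch, assuming a single nn.Linear layer and a random stand-in mask (both hypothetical). It shows that weight_orig is the registered parameter the optimizer updates, that the pruned entries of weight stay 0 after a few SGD steps, one way to reload a state dict that still contains weight_orig and weight_mask (by pruning the fresh module first, here with prune.identity), and what prune.remove does.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)
mask = (torch.rand_like(layer.weight) > 0.5).float()  # hypothetical stored mask
prune.custom_from_mask(layer, name="weight", mask=mask)

# weight_orig is the registered parameter; weight is a plain attribute
assert sorted(n for n, _ in layer.named_parameters()) == ["bias", "weight_orig"]

# Option 1: fine-tune directly; the optimizer updates weight_orig and bias
opt = torch.optim.SGD(layer.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 3)
for _ in range(5):
    opt.zero_grad()
    loss = nn.functional.mse_loss(layer(x), y)
    loss.backward()
    opt.step()

# The forward pre-hook recomputes weight = weight_orig * weight_mask on every
# forward, so the pruned entries of the effective weight are still 0
layer(x)
assert torch.all(layer.weight[mask == 0] == 0)

# Saving keeps the reparametrization: the state dict has weight_orig and weight_mask
state = layer.state_dict()
# To reload, prune the fresh module first (all-ones mask) so the keys match
fresh = nn.Linear(4, 3)
prune.identity(fresh, name="weight")
fresh.load_state_dict(state)
fresh(x)  # a forward pass recomputes fresh.weight from the loaded tensors

# Option 2: make the pruning permanent once fine-tuning is done
prune.remove(layer, "weight")
assert "weight" in layer.state_dict() and "weight_orig" not in layer.state_dict()
```

Until prune.remove is called, the zeros are enforced by the hook. Afterwards weight is an ordinary parameter again, so any further training could make the former zeros nonzero; that is why removing the reparametrization is typically done only after fine-tuning has finished.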