Hi all,
I am trying to add a custom global pruning method by extending the PyTorch pruning class. I already have the mask for each layer stored in a dict, where the key is the module name (like conv1) and the value is the mask for that module. But I want to use the torch class so that when I use the model later, those weights remain 0.
Any idea how I can do this? It would be a great help.
To achieve this, you can extend the PyTorch pruning class and use the torch.nn.utils.prune module. Here's a high-level approach:
- Create a custom pruning class:
  - Extend the BasePruningMethod class from torch.nn.utils.prune.
  - Implement your custom pruning logic in the compute_mask method.
- Apply the custom pruning method:
  - Use the prune module (for example, prune.global_unstructured) to apply your custom method to the model parameters.
Here’s a basic example to get you started:
```python
import torch
import torch.nn.utils.prune as prune

class CustomPruning(prune.BasePruningMethod):
    # global_unstructured requires an unstructured pruning type
    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        # Your custom logic to compute the mask; start from default_mask
        # so that already-pruned entries stay pruned
        mask = default_mask.clone()
        # Example: zero out weights whose magnitude is below a threshold
        mask[t.abs() < 0.1] = 0
        return mask

# Example model
model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.ReLU(),
    torch.nn.Linear(5, 1),
)

# Apply custom pruning globally; global_unstructured takes an iterable of
# (module, parameter_name) pairs and the pruning method *class* (the
# threshold is hard-coded in compute_mask here, so no amount is passed)
parameters_to_prune = [
    (model[0], "weight"),
    (model[2], "weight"),
]
prune.global_unstructured(parameters_to_prune, pruning_method=CustomPruning)
```
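After this call, each pruned module exposes weight_orig (the learnable parameter) and weight_mask (a buffer), while weight becomes a derived attribute. You can verify that the pruning took effect:

```python
# weight is recomputed as weight_orig * weight_mask on every forward pass
print("weight_orig" in dict(model[0].named_parameters()))  # True
print("weight_mask" in dict(model[0].named_buffers()))     # True
sparsity = (model[0].weight == 0).float().mean()
print(f"sparsity of model[0].weight: {sparsity:.2%}")
```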
This example demonstrates how to create a custom pruning class and apply it to a model. You can adjust the compute_mask method to use your stored masks and ensure the weights remain zero when you reload the model.
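Since you already have the masks in a dict, a simpler route may be torch.nn.utils.prune.custom_from_mask, which applies a precomputed mask directly and sets up the same weight_orig/weight_mask reparametrization. A minimal sketch, assuming a model whose module names match your dict keys (the Net/conv1 names here are just placeholders):

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3)

model = Net()

# Placeholder for your precomputed masks, keyed by module name
stored_masks = {"conv1": (model.conv1.weight.abs() > 0.1).float()}

for name, module in model.named_modules():
    if name in stored_masks:
        # Registers weight_orig and weight_mask, and keeps
        # weight = weight_orig * weight_mask via a forward pre-hook
        prune.custom_from_mask(module, name="weight", mask=stored_masks[name])
```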
Hey thanks, I did something like this. I have one more question:
I pruned the model, so it now has weight_orig, weight_mask, and a weight attribute. During fine-tuning I want to keep the pruning, so that only the weights that are not 0 get updated. I found 2 ways of doing it, correct me if I am wrong:
- Fine-tune the pruned model directly. The question here is: does weight_orig get updated, or the weight attribute? And if weight_orig updates, then the final weight (parameter) is something like weight = weight_orig * weight_mask. When I save this model, the state dict contains weight_orig and weight_mask, so when I want to use this fine-tuned pruned model again, I will load it and compute .weight as weight_orig * weight_mask.
- After fine-tuning, should I remove the reparametrization (e.g. via prune.remove), so that when I load the saved model again I only get the pruned weights in .weight? (Rough sketch below.)
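For reference, here is roughly what I mean by the two options (using l1_unstructured just as a stand-in for my custom method):

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(10, 5)
prune.l1_unstructured(layer, name="weight", amount=0.5)

# Option 1: fine-tune with the reparametrization in place.
# The optimizer updates weight_orig; a forward pre-hook recomputes
# weight = weight_orig * weight_mask on every forward pass, so pruned
# entries stay zero. The state dict then contains weight_orig and weight_mask.
print("weight_orig" in dict(layer.named_parameters()))  # True
print("weight_mask" in dict(layer.named_buffers()))     # True

# Option 2: after fine-tuning, make the pruning permanent.
# prune.remove folds the mask into weight and drops weight_orig/weight_mask,
# so the saved state dict contains a plain weight with the zeros baked in.
prune.remove(layer, "weight")
print("weight" in dict(layer.named_parameters()))  # True
torch.save(layer.state_dict(), "pruned_layer.pt")  # hypothetical filename
```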