Finetuning a model after pruning - Autograd question

I want to prune after training and then finetune the pruned model. If I use the torch.nn.utils.prune library, as far as I understand it, the weights of a layer are first zeroed using the pruning mask during the forward pass (via a forward pre-hook). This, however, makes the masking part of the backward step, so it affects the actual gradient updates.

What I want to do is the following:
I want to prune a model and then continue training while simply ignoring the pruned weights, as if they had been removed. The mask works correctly in the forward pass, since all pruned weights are set to 0, but wouldn't I get different gradients during backpropagation? How can I do this within PyTorch?
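For reference, a minimal sketch of the setup I mean (the layer size and pruning amount are just placeholders):

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(16, 8)  # toy layer

# Prune 30% of the weights by L1 magnitude. This replaces the parameter
# `weight` by `weight_orig`, adds a `weight_mask` buffer, and registers a
# forward pre-hook that recomputes weight = weight_orig * weight_mask
# before every forward pass.
prune.l1_unstructured(layer, name="weight", amount=0.3)

print([name for name, _ in layer.named_parameters()])  # contains 'weight_orig'
print([name for name, _ in layer.named_buffers()])     # contains 'weight_mask'
```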

Thanks a lot!

Hi Max,
Have you solved this problem? I have the same issue, even though I try to mask the pruned weights in every iteration. I would be happy to discuss this if you have any ideas.

Thank you!

Hey, yes, actually the prune module behaves just as I wanted it to. If you prune a module, then the weights will be zero in the forward pass and the gradients will be zero as well.

Does that answer your question?
Best,
Max

Hey Max,
so reading your reply, I understand that if we do not want to update the pruned weights during fine-tuning, the pruning module already takes care of that?
So the gradients of pruned weights are zero by default?

Yes, exactly. Gradients of pruned weights will be zero (I experimentally verified this).
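You can check this quickly yourself; here is a toy sketch (the layer and data are arbitrary):

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(16, 8)
prune.l1_unstructured(layer, name="weight", amount=0.5)

# One forward/backward pass on random data
layer(torch.randn(4, 16)).sum().backward()

mask = layer.weight_mask        # 0/1 buffer created by the prune call
grad = layer.weight_orig.grad   # gradient w.r.t. the underlying parameter

# Since the forward pass uses weight = weight_orig * weight_mask, the
# gradient at pruned positions is exactly zero.
print(torch.all(grad[mask == 0] == 0))  # tensor(True)
```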
Best,
Max

Hey @mzimmer ,
again I would appreciate your help as I am trying to prune iteratively right now:
From my Conv modules, I pruned the weights using nn.utils.prune.
Now I want to finetune the model again - more specifically only the Conv modules, as the Linear modules were not pruned. Which parameters do I pass to the optimizer? weight_orig or weight?

Because model.named_parameters() returns only weight_orig and bias of each module.
The pruned weight is now accessible directly by calling module.weight.

So which parameters do I need to pass to the optimizer when I want to finetune the model again? Because I want pruned weights (which are basically just weights with the value zero) to stay zero.
My intuition says the pruned weights need to be passed now - how can I do that easily?
Maybe removing the weight_orig of the module using the remove function of the nn.utils.prune library?

Happy to hear your thoughts - maybe you can help me out.
Thanks!

UPDATE: I tried removing the weight_orig of the module using the remove function of nn.utils.prune. It works in principle, but weights that were pruned (i.e., set to zero) were updated during training again, resulting in non-zero weights. So this cannot be the solution.
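For reference, roughly what I tried, sketched on a toy model (my real model is different):

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

# Toy stand-in for my pruned network
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Flatten(), nn.Linear(8 * 30 * 30, 10))
for module in model.modules():
    if isinstance(module, nn.Conv2d):
        prune.l1_unstructured(module, name="weight", amount=0.5)

# Make the pruning permanent: remove() deletes `weight_orig` and `weight_mask`
# and turns `weight` back into an ordinary parameter (currently zero at the
# pruned positions).
for module in model.modules():
    if isinstance(module, nn.Conv2d):
        prune.remove(module, "weight")

# Problem: nothing keeps those zeros at zero anymore, so the optimizer updates
# them again during finetuning.
```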

Hey @davidweb,
so I am not 100% sure about this, but as soon as you call remove(), the zeroed weights will get updated again.

Have you tried the following:

  1. You should just pass model.parameters() to the optimizer, as usual.
  2. If I understand you correctly, you don’t want the Linear modules to be updated during finetuning, right? Just continue training the full model, but freeze the layers that you don’t want to be updated, check for example this link.

I think this should do what you want. As long as you do not remove the pruning reparametrization, the pruned weights won't get updated. Freezing lets you disable updates for the layers you don't want to change; see the sketch below. Let me know whether this fixes your issue.
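Something along these lines, as a sketch (a toy model and hyperparameters standing in for yours):

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Toy model: pruned Conv layer, unpruned Linear layer
model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(),
    nn.Linear(8 * 30 * 30, 10),
)
for module in model.modules():
    if isinstance(module, nn.Conv2d):
        prune.l1_unstructured(module, name="weight", amount=0.5)

# 1. Pass model.parameters() to the optimizer as usual. As long as the pruning
#    is not remove()d, the pruned entries get zero gradients and stay zero.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 2. Freeze the Linear modules so they are not updated during finetuning.
for module in model.modules():
    if isinstance(module, nn.Linear):
        for param in module.parameters():
            param.requires_grad = False
```

Then just run your normal finetuning loop on top of this.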

Hey @mzimmer,

yes exactly, freezing the FC layers during finetuning of a pruned model is what I wanted.
Your suggestion worked perfectly - after several iterations the zeroed weights stay zero.

Thanks for helping me out! Very much appreciated.


Hi Max (@mzimmer),
Is the size of the pruned model (due to weight-mask storage) larger than that of the original model? I pruned the encoder layers of my model, and after saving it I noticed that the model file got bigger.

Best,
Thank you

Hi @Hajar_Mazaheri ,
yes, this will increase the storage requirements, since for every pruned tensor you will need to store the mask as well.
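You can see it directly in the state dict of a pruned module (toy example):

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(256, 256)
print(sum(t.numel() for t in layer.state_dict().values()))  # weight + bias

prune.l1_unstructured(layer, name="weight", amount=0.9)
# The state dict now holds the dense `weight_orig` AND the dense `weight_mask`,
# so the saved checkpoint gets larger, not smaller, even at 90% sparsity.
print(list(layer.state_dict().keys()))
print(sum(t.numel() for t in layer.state_dict().values()))
```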


Thank you for your response, @mzimmer
In your opinion, is pruning the weights one of the compression techniques, or is the model lightweight? And does pruning the weights reduce the computational complexity (MACs) during training and inference?
I observed that after pruning the weights and retraining the model, the results improved.

I would be happy to know your opinion and experience on this matter.

I am not sure what you mean by pruning being one of the compression techniques or the model being lightweight. Could you elaborate?

Whether pruning reduces the number of FLOPs or MACs depends heavily on a number of things, e.g., when it is applied (post-training or during/before training), whether the sparsity is structured (removing entire filters vs. individual weights without a clear pattern in the tensor), the device used (GPU or CPU), and many others. I recommend reading the survey by Hoefler et al.: [2102.00554] Sparsity in Deep Learning: Pruning and growth for efficient inference and training in neural networks

In my experience, retraining can drastically improve the performance of the model. If this interests you, one of my papers on retraining might clarify some terminology in this regard: [2111.00843] How I Learned to Stop Worrying and Love Retraining

Best,
Max


By compression methods, I mean methods that help to obtain a lighter model for inference, especially on edge devices. I have explored methods such as quantization and knowledge distillation, and now I am investigating pruning for my model, because among the optimization methods, pruning is one of the most effective techniques.
I want to know whether zeroing the weights, in addition to improving the results, can increase inference (or training) speed, considering that the model size has also increased.

Thank you for your guidance.
Your articles will be a good guide for me in this field.

Pruning can make your model more efficient, but as mentioned, it also depends on the type of pruning. If you prune in an unstructured way (as magnitude pruning would do), then the resulting sparsity is in general not directly exploitable. More common approaches that result in direct speedups are structured pruning strategies, such as convolutional filter pruning, pruning of entire neurons, or semi-structured patterns like 2:4 or 4:8 sparsity. Hope that answers the question; otherwise feel free to check out the mentioned articles for more details!
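As a small illustration, torch.nn.utils.prune also supports structured pruning of entire convolutional filters (the shapes here are arbitrary):

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)

# Structured pruning: zero out 50% of the output filters (dim=0) with the
# smallest L2 norm, instead of zeroing individual weights.
prune.ln_structured(conv, name="weight", amount=0.5, n=2, dim=0)

# Whole filters are zeroed, which is the kind of sparsity that can be turned
# into an actually smaller/faster model by removing those filters.
mask = conv.weight_mask
print(int((mask.flatten(1).sum(dim=1) == 0).sum()))  # number of fully pruned filters
```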


Yes, exactly. @mzimmer
Your detailed explanation clarified the nuances of pruning techniques, highlighting the importance of structured approaches for achieving direct speedups. I appreciate your insights into the distinctions between unstructured and structured pruning methods, such as convolutional filter pruning and sparsities like 2:4 or 4:8.

Thank you for your expertise and assistance.

Hi Max @mzimmer
Your suggested articles gave me a lot of information about pruning. Since you have so much experience in this field, I have one more question for you.
When we use unstructured pruning (such as magnitude pruning), the zeroed weights are updated during training, so after retraining (fine-tuning) they have new values and the model is no longer sparse. Is that true?
Do you think this kind of pruning can still improve model performance if the zeroed weights are frozen during training?

Thank you!

I found the answer to my question. It comes down to the distinction between static and dynamic sparsity during training.
Dynamic sparsity during training refers to schemes that iteratively prune and add (regrow) elements during the training phase, whereas with fixed (static) sparsity the structure stays the same during training and can even be hand-tuned, such as "structured sparsity" for transformers. :blush: