Finetuning a model after pruning - Autograd question

I want to prune after training and then finetune the pruned model. If I use the torch.nn.utils.prune library, as far as I understand it, during the forward pass the weights of a layer are first zeroed using the pruning mask (via a forward pre-hook). However, this makes the masking part of the autograd graph, so it also affects the gradients computed in the backward pass and therefore the actual weight updates.
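A minimal sketch of the mechanism described above, assuming a recent PyTorch; the Conv2d shape and the 50% L1 unstructured pruning are arbitrary choices for illustration. After pruning, weight is a derived tensor recomputed as weight_orig * weight_mask by the forward pre-hook, so changes to weight_orig only show up in weight after the next forward pass.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3)

# Zero the 50% of weights with the smallest magnitude (unstructured L1 pruning).
prune.l1_unstructured(conv, name="weight", amount=0.5)

# weight is now derived: weight_orig (parameter) * weight_mask (buffer).
print(torch.equal(conv.weight, conv.weight_orig * conv.weight_mask))  # True

# Changing weight_orig does not touch weight until the forward pre-hook runs again.
with torch.no_grad():
    conv.weight_orig.fill_(1.0)
print(torch.equal(conv.weight, conv.weight_mask))  # False: still the stale value
conv(torch.randn(1, 3, 5, 5))
print(torch.equal(conv.weight, conv.weight_mask))  # True: recomputed as ones * mask
```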

What I want to do is the following:
I want to prune a model and then continue training while simply ignoring the pruned weights, as if they had been removed. The mask works correctly in the forward pass, since all pruned weights are set to 0, but wouldn't I get different gradients during backpropagation? How can I do this within PyTorch?

Thanks a lot!

Hi Max,
Have you solved this problem? I ran into the same problem, even though I tried to mask the pruned weights in every iteration. I hope we can discuss this problem if you have any ideas.

Thank you!

Hey, yes, actually the prune module behaves just as I wanted it to. If you prune a module, then the pruned weights will be zero in the forward pass and their gradients will be zero as well.

Does that answer your question?

Hey Max,
So, reading your reply, I understand that if we do not want to update pruned weights during fine-tuning, the pruning module takes care of that?
So the gradients of pruned weights are zero by default?

Yes, exactly. Gradients of pruned weights will be zero (I experimentally verified this).
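One way to reproduce that check, as a sketch with an arbitrary layer, pruning amount and random input: because weight = weight_orig * weight_mask, autograd multiplies the incoming gradient by the mask, so weight_orig.grad is exactly zero at the pruned positions. Note that the gradient lands on weight_orig (the leaf parameter), not on weight.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3)
prune.l1_unstructured(conv, name="weight", amount=0.5)

conv(torch.randn(4, 3, 10, 10)).sum().backward()

pruned = conv.weight_mask == 0
grad = conv.weight_orig.grad  # the leaf parameter receives the gradient
# weight = weight_orig * weight_mask, so d(loss)/d(weight_orig) = d(loss)/d(weight) * weight_mask
print(grad[pruned].abs().max().item())  # 0.0
print(bool((grad[~pruned] != 0).any()))
```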

Hey @mzimmer ,
again I would appreciate your help, as I am trying to prune iteratively right now:
From my Conv modules, I pruned the weights using nn.utils.prune.
Now I want to finetune the model again, more specifically only the Conv modules, since the Linear modules were not pruned. Which parameters do I pass to the optimizer? weight_orig or weight?

Because model.named_parameters() returns only weight_orig and bias for each pruned module.
The pruned weight is now accessible directly as module.weight.
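To see what the optimizer actually receives, here is a sketch with a hypothetical small model (one pruned Conv2d followed by an unpruned Linear; the sizes are arbitrary). After pruning, the conv's trainable parameter is weight_orig, the mask is a buffer, and weight is a plain tensor attribute rather than an nn.Parameter.

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

# Hypothetical model: one pruned Conv2d and one unpruned Linear layer.
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3),
    nn.Flatten(),
    nn.Linear(4 * 6 * 6, 10),
)
prune.l1_unstructured(model[0], name="weight", amount=0.3)

param_names = [n for n, _ in model.named_parameters()]
buffer_names = [n for n, _ in model.named_buffers()]
print(param_names)   # ['0.bias', '0.weight_orig', '2.weight', '2.bias']
print(buffer_names)  # ['0.weight_mask']
# The pruned weight is a plain tensor attribute, not an nn.Parameter.
print(isinstance(model[0].weight, nn.Parameter))  # False
```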

So which parameters do I need to pass to the optimizer when I want to finetune the model again? I want pruned weights (which are basically just weights with the value zero) to stay zero.
My intuition says the pruned weights need to be passed now; how can I do that easily?
Maybe by removing weight_orig from the module using the remove function of the nn.utils.prune library?

Happy to hear your thoughts; maybe you can help me out.

UPDATE: I tried removing weight_orig from the module using the remove function of nn.utils.prune. It works in principle, but weights that had been pruned (= set to zero) were updated again during training, resulting in non-zero weights. So this cannot be the solution.
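A sketch reproducing this behavior, with an arbitrary layer and a single plain SGD step for illustration: prune.remove() bakes the current mask into an ordinary weight parameter and drops weight_orig, weight_mask and the forward pre-hook. The zeros survive the call itself, but nothing masks their gradients afterwards.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3)
prune.l1_unstructured(conv, name="weight", amount=0.5)
was_pruned = conv.weight_mask == 0  # remember which entries were pruned

# Make the pruning permanent: weight becomes an ordinary Parameter again,
# and weight_orig, weight_mask and the forward pre-hook are removed.
prune.remove(conv, "weight")
zeros_after_remove = int((conv.weight[was_pruned] == 0).sum())
print(isinstance(conv.weight, nn.Parameter), zeros_after_remove)  # True 108

opt = torch.optim.SGD(conv.parameters(), lr=0.1)
opt.zero_grad()
conv(torch.randn(4, 3, 10, 10)).sum().backward()
opt.step()

# No mask is applied to the gradient any more, so formerly pruned entries move off zero.
nonzero_after_step = int((conv.weight[was_pruned] != 0).sum())
print(nonzero_after_step)
```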

Hey @davidweb,
so I am not 100% sure on this, but as soon as you call remove(), the zeroed weights will receive updates again.

Have you tried the following:

  1. You should just pass model.parameters() to the optimizer, as usual.
  2. If I understand you correctly, you don't want the Linear modules to be updated during finetuning, right? Then just continue training the full model, but freeze the layers that you don't want to be updated; see for example this link.

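I think this should do what you want. As long as you do not remove the pruning, pruned weights won't get updated: the forward pre-hook recomputes weight as weight_orig * weight_mask on every forward pass, so the pruned entries stay zero. Freezing lets you exclude certain layers of your network from the updates. Let me know whether this fixes your issue.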
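A sketch of the suggested setup, assuming a hypothetical small model whose Conv2d layers are pruned and whose Linear head is frozen with requires_grad_(False); sizes, learning rate and data are arbitrary. model.parameters() goes to the optimizer unchanged: frozen parameters never receive a .grad and are skipped by SGD, and the pruned conv entries stay zero. Momentum and weight decay may still change weight_orig at the pruned positions, but the mask is reapplied on every forward pass, so the effective weight there remains zero.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune

torch.manual_seed(0)
# Hypothetical model: two pruned Conv2d layers and a Linear head that stays frozen.
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3), nn.ReLU(),
    nn.Conv2d(4, 4, kernel_size=3), nn.ReLU(),
    nn.Flatten(),
    nn.Linear(4 * 4 * 4, 10),
)
convs = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
for conv in convs:
    prune.l1_unstructured(conv, name="weight", amount=0.5)
for m in model.modules():
    if isinstance(m, nn.Linear):
        m.requires_grad_(False)  # freeze the unpruned Linear layers

linear_before = model[5].weight.clone()
# model.parameters() yields each conv's weight_orig and bias plus the frozen Linear params.
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

x, y = torch.randn(16, 1, 8, 8), torch.randint(0, 10, (16,))
for _ in range(5):
    opt.zero_grad()
    F.cross_entropy(model(x), y).backward()
    opt.step()

# Recompute weight = weight_orig * weight_mask (otherwise done by the next forward pass).
with torch.no_grad():
    model(x)
leaked = [int((c.weight[c.weight_mask == 0] != 0).sum()) for c in convs]
print(leaked)  # [0, 0]
print(torch.equal(model[5].weight, linear_before))  # True
```

Passing only the trainable parameters, e.g. filter(lambda p: p.requires_grad, model.parameters()), would work equally well; the explicit filter just makes the intent clearer.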

Hey @mzimmer,

yes, exactly, freezing the FC layers during finetuning of a pruned model is what I wanted.
Your suggestion worked perfectly: after several iterations, the zero weights stay zero.

Thanks for helping me out! Very much appreciated.
