I want to prune a model after training and then fine-tune the pruned model. If I use the torch.nn.utils.prune library, then as far as I understand it, the weights of a layer are first zeroed with the pruning mask during the forward pass (via a forward pre-hook). This, however, makes the masking part of the backward graph, so it affects the actual gradient updates.
What I want to do is the following:
I want to prune a model and then continue training while simply ignoring the pruned weights, as if they had been removed. The mask works correctly in the forward pass, since all pruned weights are set to 0, but wouldn't I get different gradients during backpropagation? How can I do this within the PyTorch library?
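For concreteness, here is a minimal sketch of the setup I mean (the layer sizes and pruning amount are arbitrary placeholders):

```python
import torch
import torch.nn.utils.prune as prune

# Toy linear layer; sizes are placeholders
layer = torch.nn.Linear(4, 3)

# Prune 50% of the weights by L1 magnitude. This reparametrizes the
# layer: the trainable tensor becomes `weight_orig`, the binary mask is
# stored as the buffer `weight_mask`, and `weight` is recomputed as
# weight_orig * weight_mask by a forward pre-hook on every forward pass.
prune.l1_unstructured(layer, name="weight", amount=0.5)

print(sorted(name for name, _ in layer.named_parameters()))
# ['bias', 'weight_orig']
print(torch.equal(layer.weight, layer.weight_orig * layer.weight_mask))
# True

# Forward and backward pass: the pruned entries of `weight` are zero in
# the forward computation, but the gradient flows back through the
# multiplication by the mask to `weight_orig` -- this is the step whose
# effect on the updates I am unsure about.
loss = layer(torch.randn(2, 4)).sum()
loss.backward()
print(layer.weight_orig.grad.shape)
# torch.Size([3, 4])
```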
Thanks a lot!