How to apply pruning to a model with weight normalization?

I have trained a model for waveform synthesis and would now like to apply global pruning with an iterative pruning schedule. However, the model uses weight normalization in most of its layers, which means each weight is split into two components, weight_g and weight_v, from which the actual weight is recomputed on every forward pass. Applying pruning to either weight_g or weight_v alone doesn't make sense, so I need some way to temporarily recover the actual weights.
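
For concreteness, here is a minimal sketch of the situation, assuming the classic hook-based torch.nn.utils.weight_norm API (the layer type and shapes are just placeholders):

```python
import torch.nn as nn
from torch.nn.utils import weight_norm

# After weight normalization is applied, the layer no longer has a `weight`
# parameter; only `weight_g` (magnitude) and `weight_v` (direction) remain,
# and `weight` is recomputed from them by a forward pre-hook.
conv = weight_norm(nn.Conv1d(16, 32, kernel_size=3))
print([name for name, _ in conv.named_parameters()])
# ['bias', 'weight_g', 'weight_v']
```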

So far I have considered two approaches:

  1. For each pruning step, disable weight normalization, conduct the pruning, and re-enable weight normalization. This doesn’t work because pruning replaces the weight parameter with weight_orig, so once weight normalization is re-enabled there is no weight parameter left. Re-enabling weight normalization also seems to create a new set of parameters which are not tracked by the optimizer (see the first sketch after this list).
  2. For each pruning step, compute the actual weights by calling the _weight_norm function, copy the result into weight_v, conduct the pruning, and undo the swap afterwards. The problem here is that the pruning function writes into weight_v the actual weights with the pruning mask applied, which means you can’t simply swap the former values back in (see the second sketch after this list).
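
Here is roughly what approach 1 looks like and where I think it breaks down (layer shapes and the pruning amount are placeholders, and l1_unstructured stands in for the global/iterative pruning call):

```python
import torch.nn as nn
from torch.nn.utils import weight_norm, remove_weight_norm
import torch.nn.utils.prune as prune

conv = weight_norm(nn.Conv1d(16, 32, kernel_size=3))

# Step 1: disable weight normalization, which restores a plain `weight` parameter.
remove_weight_norm(conv)

# Step 2: prune. This removes the `weight` parameter and replaces it with a
# `weight_orig` parameter plus a `weight_mask` buffer; `weight` itself becomes
# a plain attribute recomputed by a forward pre-hook.
prune.l1_unstructured(conv, name="weight", amount=0.5)

# Step 3: re-enable weight normalization. This is where it falls apart:
# `weight` is no longer a parameter, so weight_norm has nothing to
# reparametrize (it fails when it tries to delete `weight` from the module's
# parameters), and even if it went through, the freshly created
# weight_g/weight_v parameters would be unknown to the optimizer.
weight_norm(conv)
```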
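
And a sketch of approach 2; torch._weight_norm is the internal function that the weight-norm pre-hook itself uses to compose the actual weight from weight_v and weight_g (dim=0 is the default for weight_norm):

```python
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
import torch.nn.utils.prune as prune

conv = weight_norm(nn.Conv1d(16, 32, kernel_size=3))

# Compose the actual weights and swap them into weight_v for the pruning step.
with torch.no_grad():
    saved_v = conv.weight_v.detach().clone()
    actual_weight = torch._weight_norm(conv.weight_v, conv.weight_g, 0)
    conv.weight_v.copy_(actual_weight)

prune.l1_unstructured(conv, name="weight_v", amount=0.5)

# Undoing the swap is the problem: pruning has replaced the weight_v parameter
# with weight_v_orig (currently holding the actual weights) plus a
# weight_v_mask buffer, and the weight_v attribute is recomputed as
# weight_v_orig * weight_v_mask. Writing saved_v back into weight_v_orig just
# gets masked again by the pruning pre-hook, so the pre-swap state can't
# simply be restored.
with torch.no_grad():
    conv.weight_v_orig.copy_(saved_v)
```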

How could I go about this? Could I make one of the described approaches work in an elegant way, or is there another possibility that I didn’t think of?