I am pruning a model to 40% sparsity. I do this iteratively: pruning_percent = 40 / num_batches, and after every mini-batch I prune the model by pruning_percent. For example, with just 2 mini-batches I prune 20% of the weights in the first mini-batch and the remaining 20% in the second.
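The schedule above can be sketched in plain Python. This is a minimal sketch, assuming pruning_percent is measured against all weights in a layer rather than against the weights still unpruned; the helper names are hypothetical. The second helper shows why that distinction matters: pruning 20% of the remaining weights twice gives 1 - 0.8^2 = 36% sparsity, not 40%.

```python
def cumulative_sparsity_schedule(target: float, num_batches: int) -> list[float]:
    """Hypothetical helper: cumulative sparsity reached after each mini-batch,
    assuming every step prunes target / num_batches of ALL weights."""
    step = target / num_batches
    return [step * (k + 1) for k in range(num_batches)]


def compounded_sparsity(step: float, num_batches: int) -> float:
    """Hypothetical helper: final sparsity if each step instead prunes
    `step` of the REMAINING (still unpruned) weights."""
    return 1.0 - (1.0 - step) ** num_batches


schedule = cumulative_sparsity_schedule(0.40, 2)   # roughly [0.2, 0.4]
compounded = compounded_sparsity(0.20, 2)          # roughly 0.36
print(schedule, compounded)
```

If the mask is rebuilt from a threshold at each step, it is the cumulative value (0.2, then 0.4) that the threshold should target, not the per-step 20%.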
To do this, I first pruned the model with an all-ones mask so that each layer gets weight_orig and weight_mask. I then run the iterative pruning with my own logic: based on a threshold I decide which weights to prune and update the mask with layer.weight_mask = mask[layer]. This does prune the model to 40% sparsity, but when I run my evaluation method on the pruned model I get NaN outputs.
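A minimal sketch of the mechanism described above, using torch.nn.utils.prune.identity to create weight_orig and weight_mask, then overwriting the mask at each step. The toy model, the random input batch, and the magnitude_mask helper are hypothetical stand-ins for the real model, data, and threshold logic. The torch.isfinite check after each forward pass is one way to find the first step at which outputs stop being finite.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Hypothetical toy model standing in for the real one.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
layers = [m for m in model if isinstance(m, nn.Linear)]

# Attach an all-ones mask: each layer now has weight_orig (parameter),
# weight_mask (buffer), and a forward pre-hook computing weight = orig * mask.
for layer in layers:
    prune.identity(layer, "weight")


def magnitude_mask(layer: nn.Linear, sparsity: float) -> torch.Tensor:
    """Hypothetical helper: zero the `sparsity` fraction of smallest-|w| entries
    of weight_orig, counted against ALL weights (not the remaining ones)."""
    w = layer.weight_orig.detach().abs()
    k = int(round(sparsity * w.numel()))
    mask = torch.ones_like(w)
    if k > 0:
        idx = torch.topk(w.flatten(), k, largest=False).indices
        mask.view(-1)[idx] = 0.0  # writes through the view into `mask`
    return mask


num_batches, target = 2, 0.40
x = torch.randn(32, 8)  # hypothetical stand-in for an evaluation batch

for step in range(1, num_batches + 1):
    sparsity = target * step / num_batches  # cumulative target after this batch
    with torch.no_grad():
        for layer in layers:
            # Update the registered buffer in place; the pre-hook uses it
            # on the next forward pass.
            layer.weight_mask.copy_(magnitude_mask(layer, sparsity))
    out = model(x)
    assert torch.isfinite(out).all(), f"non-finite output after step {step}"

for layer in layers:
    print((layer.weight == 0).float().mean().item())  # close to 0.4 per layer
```

Because the mask is a buffer that never requires grad, updating it under torch.no_grad() does not touch the autograd graph; gradients keep flowing to weight_orig through the masked product.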
Is it because I am somehow messing up the gradients while pruning?
Any thoughts on this?