Hello again.
I decided to give pruning another chance and found the error in the question code:
bias_name = attr.replace('weight', 'bias')
setattr(layer, attr, torch.nn.Parameter(weight[sorted_indices].view(-1, weight.shape[-1])))
setattr(layer, bias_name, torch.nn.Parameter(getattr(layer, bias_name)[sorted_indices])) # <--- HERE
When reindexing the bias, I forgot that it has to be chunked first: the bias has shape (hidden_size*4,), while my sorted_indices variable was meant to index tensors of shape (hidden_size,).
So I simply chunk the bias tensor ((hidden_size*4,) → (hidden_size, 4)), reindex it over the 0th dim, and reshape it back to a 1-d tensor:
bias_name = attr.replace('weight', 'bias')
setattr(layer, attr, torch.nn.Parameter(weight[sorted_indices].view(-1, weight.shape[-1])))
setattr(layer, bias_name, torch.nn.Parameter(getattr(layer, bias_name).view(-1, 4)[sorted_indices].view(-1)))
And that’s it; no NaNs were produced.
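As a sanity check, the chunk-reindex-flatten trick can be exercised on a toy bias vector in isolation (hidden_size and sorted_indices here are made-up values, not taken from the original model):

```python
import torch

hidden_size = 3
# Toy bias of shape (hidden_size * 4,), filled with 0..11 so we can
# see exactly which elements end up where after reindexing.
bias = torch.arange(hidden_size * 4, dtype=torch.float32)

# Hypothetical pruning result: keep hidden units 2 and 0, in that order.
sorted_indices = torch.tensor([2, 0])

# Chunk (hidden_size*4,) -> (hidden_size, 4), reindex over dim 0,
# then flatten back to a 1-d tensor, as in the fixed setattr line.
pruned = bias.view(-1, 4)[sorted_indices].view(-1)

print(pruned)        # tensor([ 8.,  9., 10., 11.,  0.,  1.,  2.,  3.])
print(pruned.shape)  # torch.Size([8]) == (len(sorted_indices) * 4,)
```

Each group of 4 consecutive bias values travels together with its hidden unit, which is exactly what the single-index version was failing to do.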
That said, my RNN pruning method is still fundamentally flawed: pruning the hidden size of an RNN stresses the net so much that it loses all of its capacity, which results in a ~20x loss rise (and a ~20x metric drop), so the net has to be retrained from scratch. For comparison, when I pruned only the convolutional part, there wasn’t any significant metric drop (though the loss rose ~1.5x).