What is the right way to apply pruning to LSTMs?

Hello again.

I decided to give pruning another chance and found the error in the question code:

bias_name = attr.replace('weight', 'bias')
setattr(layer, attr, torch.nn.Parameter(weight[sorted_indices].view(-1, weight.shape[-1])))
setattr(layer, bias_name, torch.nn.Parameter(getattr(layer, bias_name)[sorted_indices]))  # <--- HERE

When reindexing the bias, I forgot that it has to be chunked first. The bias has shape (hidden_size*4,), while my sorted_indices variable was meant to index tensors of shape (hidden_size,).
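To make the shape mismatch concrete, here is a minimal, hypothetical reproduction (the names mirror the snippet, but the values are made up): indexing the flat (hidden_size*4,) bias with indices built for a (hidden_size,) tensor silently selects only hidden_size of the 4*hidden_size entries instead of reindexing every gate's bias.

```python
import torch

H = 4  # hypothetical hidden_size
bias = torch.arange(4 * H, dtype=torch.float32)  # stands in for an LSTM bias, shape (4*H,)
sorted_indices = torch.tensor([3, 1, 0, 2])      # a permutation of the H hidden units

wrong = bias[sorted_indices]  # shape (H,): 12 of the 16 bias entries are dropped
print(wrong.shape)            # torch.Size([4])
```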

So I simply chunk the bias tensor ((hidden_size*4,) → (hidden_size, 4)), reindex it along dim 0, and reshape it back to a 1-D tensor:

bias_name = attr.replace('weight', 'bias')
setattr(layer, attr, torch.nn.Parameter(weight[sorted_indices].view(-1, weight.shape[-1])))
# chunk the (hidden_size*4,) bias, reindex along dim 0, then flatten back to 1-D
setattr(layer, bias_name, torch.nn.Parameter(getattr(layer, bias_name).view(-1, 4)[sorted_indices].view(-1)))

And that’s it: no NaNs were produced.
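One caveat worth noting: a stock torch.nn.LSTM stores weight_ih_l0 with shape (4*hidden_size, input_size) and bias_ih_l0 with shape (4*hidden_size,), as four concatenated gate chunks (i, f, g, o), each of length hidden_size. Under that layout the chunked view is (4, hidden_size) with indexing along dim 1; whichever view you use has to match how sorted_indices was derived. A self-contained sketch under the standard layout (keep_idx and the sizes are made-up examples):

```python
import torch

# Prune the hidden units of a stock nn.LSTM's first layer down to keep_idx.
# Assumes the standard PyTorch gate layout: weight_ih_l0 is (4*H, input_size)
# and bias_ih_l0 is (4*H,), both stored as four concatenated gate chunks.
hidden_size, input_size = 4, 3
lstm = torch.nn.LSTM(input_size, hidden_size)

keep_idx = torch.tensor([3, 1, 0])  # hidden units to keep (hypothetical ranking)

with torch.no_grad():
    w = lstm.weight_ih_l0  # (4*H, input_size)
    b = lstm.bias_ih_l0    # (4*H,)
    # View as (4, H, ...) so keep_idx indexes the hidden dimension of each gate.
    w_pruned = w.view(4, hidden_size, -1)[:, keep_idx].reshape(-1, input_size)
    b_pruned = b.view(4, hidden_size)[:, keep_idx].reshape(-1)

print(w_pruned.shape)  # torch.Size([12, 3])
print(b_pruned.shape)  # torch.Size([12])
```

`.reshape(-1)` is used after indexing because the indexed view need not be contiguous, in which case `.view(-1)` would raise.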

My RNN pruning method is still flawed, though: pruning the hidden size of the RNN stresses the net so much that it loses nearly all of its capacity, which results in a ~20x loss rise (and a ~20x metric drop), so the net has to be fully retrained. By contrast, when I pruned only the convolutional part, there was no significant metric drop (though the loss still rose ~1.5x).
