What is the right way to apply pruning to LSTMs?

Hello everyone!

I’m currently studying model optimization methods, and right now I’m trying out pruning. I’ve chosen a CRNN for my experiments because it has different types of layers following one another.

Net architecture diagram attached, in case it helps.

I prune layers by iteratively selecting channel indices with the least L1 norm:

  1. Calculate the L1 norm channel-wise.
  2. Select the n% of channel indices with the smallest L1 norms.
  3. Modify the current layer’s weights by removing OUT channels with indices from step 2.
  4. Modify the next layer’s weights by removing IN channels with indices from step 2.
  5. Loop over every layer (a minimal sketch of steps 1-4 for the conv case is shown right below).
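
To make the conv case concrete, here’s a minimal sketch of steps 1-4 for two consecutive Conv2d layers (the layer shapes and the fraction value are made up for illustration, they’re not my actual CRNN):

import torch
import torch.nn as nn

conv1 = nn.Conv2d(16, 32, kernel_size=3)  # current layer
conv2 = nn.Conv2d(32, 64, kernel_size=3)  # next layer
fraction = 0.25                           # prune 25% of conv1's OUT channels

# 1. L1 norm per OUT channel: sum |w| over (in_channels, kH, kW).
l1 = conv1.weight.detach().abs().sum(dim=(1, 2, 3))

# 2. Drop the fraction with the smallest norms, i.e. keep the rest (ascending order).
keep = torch.sort(torch.argsort(l1)[int(fraction * conv1.out_channels):]).values

# 3. Remove OUT channels from the current layer.
conv1.weight = nn.Parameter(conv1.weight.detach()[keep].clone())
conv1.bias = nn.Parameter(conv1.bias.detach()[keep].clone())
conv1.out_channels = len(keep)

# 4. Remove the matching IN channels from the next layer.
conv2.weight = nn.Parameter(conv2.weight.detach()[:, keep].clone())
conv2.in_channels = len(keep)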

This scheme works correctly on the convolutional part of the net, but when I apply it to the whole net, it starts to produce nans. My assumption is that something is wrong with my handling of the LSTM weights (the nans occur right after the LSTM). However, I’m able to correctly prune the first LSTM layer’s IN channels (so there is no shape conflict).

Here’s the part of my code that processes the output side of the LSTM weights (the output rows of the ih weights, plus both the input columns and output rows of the hh weights):

elif isinstance(layer, nn.LSTM):
  prev_layer_output_size = layer.hidden_size

  # Prune the OUT rows of every weight_ih_l* tensor (and its bias).
  for attr in dir(layer):
    if attr.startswith('weight_ih'):
      # Group the (4*hidden_size, input_size) weight as (hidden_size, 4, input_size).
      weight = getattr(layer, attr).view(layer.hidden_size, 4, -1)
      # L1 norm per hidden channel, summed over the gate and input dims.
      l1 = torch.sum(torch.abs(weight), dim=(1, 2))
      # Keep the channels with the largest norms, indices in ascending order.
      sorted_indices = torch.argsort(l1)[int(fraction*layer.hidden_size):]
      sorted_indices = torch.sort(sorted_indices).values

      bias_name = attr.replace('weight', 'bias')
      setattr(layer, attr, torch.nn.Parameter(weight[sorted_indices].view(-1, weight.shape[-1])))
      setattr(layer, bias_name, torch.nn.Parameter(getattr(layer, bias_name)[sorted_indices]))

  # Prune every weight_hh_l* tensor: first drop its IN columns using the
  # indices kept above, then drop its OUT rows with freshly computed indices.
  for attr in dir(layer):
    if attr.startswith('weight_hh'):
      weight = getattr(layer, attr)[:, sorted_indices].view(layer.hidden_size, 4, -1)
      l1 = torch.sum(torch.abs(weight), dim=(1, 2))
      sorted_indices = torch.argsort(l1)[int(fraction*layer.hidden_size):]
      sorted_indices = torch.sort(sorted_indices).values

      bias_name = attr.replace('weight', 'bias')
      setattr(layer, attr, torch.nn.Parameter(weight[sorted_indices].view(-1, weight.shape[-1])))
      setattr(layer, bias_name, torch.nn.Parameter(getattr(layer, bias_name)[sorted_indices]))

  # Record the new (pruned) hidden size.
  layer.hidden_size = len(sorted_indices)

I’ve read the docs and the LSTM implementation, but I wasn’t able to fix the error.
Finally, some observations that might help:

  1. I’ve checked the model weights for nans, and there weren’t any.
  2. The model produces nans non-deterministically. I use torchinfo.summary(…) or my own validation function to check for nans, so there’s a lot of randomness in the input data; some torchinfo.summary(…) runs or validation iterations didn’t return nans for some reason (a rough sketch of these checks is below).
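
For reference, the checks are conceptually just the following (model and val_loader here are placeholders, not my actual code):

import torch
import torch.nn as nn

# Placeholders standing in for the real net and validation data:
model = nn.Sequential(nn.Linear(8, 4))
val_loader = [(torch.randn(2, 8), None)]

# 1. Scan every parameter tensor for nans.
print(any(torch.isnan(p).any() for p in model.parameters()))

# 2. Scan the model outputs on validation batches for nans.
with torch.no_grad():
  for x, _ in val_loader:
    if torch.isnan(model(x)).any():
      print('nan in model output')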

Thanks in advance. I would really appreciate your help.

Hello again.

I decided to give pruning another chance and found the error in the question code:

bias_name = attr.replace('weight', 'bias')
setattr(layer, attr, torch.nn.Parameter(weight[sorted_indices].view(-1, weight.shape[-1])))
setattr(layer, bias_name, torch.nn.Parameter(getattr(layer, bias_name)[sorted_indices]))  # <--- HERE

When reindexing the bias, I forgot that it has to be chunked before reindexing: the bias has shape (hidden_size*4,), while my sorted_indices variable was meant to index tensors of shape (hidden_size,).
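
Here’s the mismatch with toy numbers (hidden_size and the kept indices are made up):

import torch

hidden_size = 4
bias = torch.arange(hidden_size * 4)  # shape (16,), stands in for an LSTM bias
keep = torch.tensor([0, 2, 3])        # indices meant for a (hidden_size,) tensor

wrong = bias[keep]                       # picks 3 scalars out of 16 -> shape (3,)
right = bias.view(-1, 4)[keep].view(-1)  # group, reindex dim 0, flatten -> shape (12,)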

So I simply chunk the bias tensor ((hidden_size*4,) → (hidden_size, 4)), reindex it along dim 0, and reshape it back to a 1-D tensor:

bias_name = attr.replace('weight', 'bias')
setattr(layer, attr, torch.nn.Parameter(weight[sorted_indices].view(-1, weight.shape[-1])))
setattr(layer, bias_name, torch.nn.Parameter(getattr(layer, bias_name).view(-1, 4)[sorted_indices].view(-1)))

And that’s it; no nans were produced.

That said, my RNN pruning method is completely incorrect in practice: pruning the hidden size of an RNN stresses the net so much that it loses all of its capacity, which results in a ~20x loss rise (and a ~20x metric drop), so the net has to be totally retrained. For comparison, when I pruned only the convolutional part, there wasn’t any significant metric drop (though the loss rose ~1.5x).
