Pruning at initialization does not increase efficiency

Hello everyone,

I recently started exploring pruning methods for Deep Neural Networks and came across some interesting papers proposing algorithms for unstructured pruning at initialization (e.g., SNIP). Unlike structured pruning, unstructured pruning removes individual weights instead of whole neurons, i.e., single matrix elements are pruned instead of full rows or columns (depending on notation). Several publications state that unstructured pruning achieves higher compression ratios.
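The difference between the two styles is easiest to see on a single layer. Below is a minimal sketch. It assumes torch.nn.utils.prune and an arbitrary toy nn.Linear(16, 8); the sizes and pruning amounts are illustrative only. Unstructured L1 pruning scatters zeros across the weight matrix. L2-structured pruning along dim=0 instead zeroes whole rows, i.e., whole output neurons.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Two toy layers; nn.Linear stores its weight as [out_features, in_features]
unstructured = nn.Linear(16, 8)
structured = nn.Linear(16, 8)

# Unstructured: zero the 60% of individual weights with the smallest |w|
prune.l1_unstructured(unstructured, name="weight", amount=0.6)

# Structured: zero the 50% of rows (output neurons) with the smallest L2 norm
prune.ln_structured(structured, name="weight", amount=0.5, n=2, dim=0)

def zero_fraction(t):
    """Fraction of entries that are exactly zero."""
    return (t == 0).float().mean().item()

def zero_rows(t):
    """Number of rows that are entirely zero."""
    return int((t == 0).all(dim=1).sum().item())

print(zero_fraction(unstructured.weight), zero_rows(unstructured.weight))
print(zero_fraction(structured.weight), zero_rows(structured.weight))  # 0.5 4
```

In both cases the weight keeps its dense [8, 16] shape. Only the pattern of zeros differs.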

Essentially all the research papers I have read state that pruning improves efficiency and reduces memory requirements, which I have not been able to verify by experimenting with pruning in PyTorch. For example, when using ResNet50 with pretrained weights on CIFAR10 and retraining the model for 20 epochs, I get:

  • M = 27.8995 (SD = 1.1458) seconds per training epoch for the full model

  • M = 28.7073 (SD = 0.1553) seconds per training epoch for the pruned model (60% of weights in linear and conv layers pruned)

  • M = 11.2734 (SD = 0.3740) seconds per evaluation of the test data for the full model (20 iterations)

  • M = 12.1780 (SD = 0.1735) seconds per evaluation of the test data for the pruned model (20 iterations)

The pruned model is significantly slower for both training (t = 3.04, p < .001) and inference (t = -9.5652, p < .001), which is the exact opposite of what the papers state. I trained on Google Colab, so I assume the results are not confounded by background tasks. For reference, please see the minimal working example (IPYNB). The same issue occurs with ResNet18, though it is less pronounced because that model is smaller and faster overall.
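For comparisons like this, it can help to time the forward pass on its own. CUDA kernels are launched asynchronously, so the clock should only be read after torch.cuda.synchronize(). The first few iterations (lazy initialization, memory allocation) are also best discarded as warm-up. Below is a minimal sketch of such a helper. The name time_forward, the warm-up/run counts and the toy model are illustrative assumptions, not taken from the original notebook.

```python
import statistics
import time

import torch
import torch.nn as nn

def time_forward(model, inputs, n_warmup=3, n_runs=10):
    """Return (mean, stdev) of wall-clock seconds per forward pass (hypothetical helper)."""
    is_cuda = inputs.is_cuda
    model.eval()
    times = []
    with torch.no_grad():
        for i in range(n_warmup + n_runs):
            if is_cuda:
                torch.cuda.synchronize()  # make sure earlier work has finished
            start = time.perf_counter()
            model(inputs)
            if is_cuda:
                torch.cuda.synchronize()  # wait for this pass's kernels to complete
            elapsed = time.perf_counter() - start
            if i >= n_warmup:  # discard warm-up iterations
                times.append(elapsed)
    return statistics.mean(times), statistics.stdev(times)

# Usage with a toy model; replace with the full and pruned networks to compare them
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))
x = torch.randn(32, 64)
mean_s, std_s = time_forward(model, x)
print(f"{mean_s * 1e3:.3f} ms (SD {std_s * 1e3:.3f} ms)")
```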

From my understanding, this behavior is expected, since the pruning is implemented via a binary mask that is applied in every forward pass (the pruned weights appear to still be updated in the backward pass). I also tried the prune.remove method, but it is not suitable for pruning before training: the pruned weights are then treated as regular weights initialized to 0 and are updated throughout training.
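The mask-based reparametrization described above can be inspected directly. Below is a minimal sketch, assuming torch.nn.utils.prune and a toy nn.Linear(4, 3). It shows that, while the mask is in place, masked entries of weight_orig receive exactly zero gradient from the loss. Optimizer weight decay can still shrink weight_orig, but the effective weight stays zero. After prune.remove, the former zeros receive ordinary gradients again.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)
prune.random_unstructured(layer, name="weight", amount=0.5)

# The original values now live in the Parameter `weight_orig`, the binary mask in
# the buffer `weight_mask`, and `weight` is recomputed as weight_orig * weight_mask
# by a forward pre-hook before every forward call.
param_names = [n for n, _ in layer.named_parameters()]
buffer_names = [n for n, _ in layer.named_buffers()]
print(param_names)   # ['bias', 'weight_orig']
print(buffer_names)  # ['weight_mask']

x = torch.randn(5, 4)
layer(x).sum().backward()
mask = layer.weight_mask.clone()
# d(loss)/d(weight_orig) = d(loss)/d(weight) * mask, so masked entries get zero gradient
masked_grad_max = layer.weight_orig.grad[mask == 0].abs().max().item()
print(masked_grad_max)  # 0.0

# prune.remove makes the pruning permanent: `weight` becomes an ordinary Parameter
# that happens to contain zeros, and the mask and hook are discarded.
prune.remove(layer, "weight")
layer.weight.grad = None  # drop the gradient left over from the masked pass
layer(x).sum().backward()
# Without the mask, the former zeros get non-zero gradients and an optimizer
# step would move them away from zero.
regrown_grad_max = layer.weight.grad[mask == 0].abs().max().item()
print(regrown_grad_max > 0)  # True
```

One common pattern is to keep the mask for the whole training run and call prune.remove only afterwards, e.g., before saving the model.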

However, the papers I read made it sound as if the implementation produces some sort of sparsity structure that speeds up the forward and backward passes. After some research, it appears that sparse training is not currently supported in PyTorch, and sparse inference does not seem to be possible with ResNets either (at least with torch==1.13 and CUDA, due to the Conv2d layers).
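For fully connected layers, sparse inference can at least be expressed with torch.sparse, even though this does not cover the Conv2d layers in ResNets. The sketch below uses made-up layer sizes and roughly 90% random sparsity, both chosen only for illustration. It checks that the sparse product matches the dense one. Whether it is actually faster depends on the sparsity level, shapes and hardware, so it should be timed rather than assumed.

```python
import torch

torch.manual_seed(0)
out_features, in_features, batch = 256, 512, 32

# Dense weight with roughly 90% of its entries set to zero (unstructured sparsity)
weight = torch.randn(out_features, in_features)
weight[torch.rand(out_features, in_features) < 0.9] = 0.0
x = torch.randn(batch, in_features)

# COO copy of the weight: only the non-zero values and their indices are stored
sparse_weight = weight.to_sparse()

# nn.Linear computes x @ W^T; torch.sparse.mm takes the sparse operand first,
# so compute (W @ x^T)^T instead
y_sparse = torch.sparse.mm(sparse_weight, x.t()).t()
y_dense = x @ weight.t()

print(sparse_weight._nnz(), "non-zeros out of", weight.numel())
print(torch.allclose(y_sparse, y_dense, atol=1e-4))  # True
```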

In summary: none of the methods I tried (more than the ones in the minimal example) led to an increase in efficiency, for either training or inference. Notably, while the papers I read state that pruning is beneficial, none of them report any efficiency measurements to support their claims, which adds to my confusion. So my question is:

Is there an actual, practical benefit from pruning at initialization or are the assumptions purely theoretical?

Please note that I am still fairly new to PyTorch, having only used Keras before. It is entirely possible that I just messed something up in my code, so any help is greatly appreciated!

Thanks and happy holidays,

Note: As this appears to be a bigger issue for pruning, I also asked this question on StackOverflow.

You should be able to verify that the pruning functions are doing something by, e.g., printing some of the weight values before and after pruning and checking that some of them are indeed set to zero. However, applying such pruning alone would not result in any speedup without further modifications to the model, because the weights are simply replaced with zeros. Performing an operation like a matrix multiplication or a convolution with zeroed weights won't be any faster in the common case of dense implementations, which have no knowledge of the underlying sparsity.
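The check described above could look like the following. This is a minimal sketch that assumes an arbitrary toy nn.Conv2d(3, 8, kernel_size=3) and 60% L1 unstructured pruning. It counts exact zeros before and after pruning and confirms that the weight tensor keeps its full dense shape, which is why a dense kernel still does the same amount of work.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3)  # weight shape [8, 3, 3, 3] = 216 entries

before = conv.weight.detach().clone()
prune.l1_unstructured(conv, name="weight", amount=0.6)
after = conv.weight.detach()

zeros_before = int((before == 0).sum())
zeros_after = int((after == 0).sum())
print("zeros before:", zeros_before)
print("zeros after: ", zeros_after, "of", after.numel())  # 130 of 216
# The tensor keeps its dense shape; pruning only writes zeros into it
print(before.shape == after.shape)  # True
```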

You can verify this by, e.g., testing the model with fully zeroed weights and comparing its performance.
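A quick way to run that experiment is to time identical layers once with random weights and once with all-zero weights. Below is a minimal sketch in which the layer sizes, input shape and helper name mean_forward_time are arbitrary assumptions. With dense kernels, one would expect the two timings to be roughly the same, since the same number of multiply-adds is executed either way.

```python
import copy
import time

import torch
import torch.nn as nn

def mean_forward_time(model, x, n_runs=10):
    """Average seconds per forward pass after one warm-up call (hypothetical helper)."""
    model.eval()
    with torch.no_grad():
        model(x)  # warm-up: lazy initialization, kernel setup
        if x.is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_runs):
            model(x)
        if x.is_cuda:
            torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
        return (time.perf_counter() - start) / n_runs

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
dense = nn.Sequential(
    nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(), nn.Conv2d(64, 64, 3, padding=1)
).to(device)

zeroed = copy.deepcopy(dense)
with torch.no_grad():
    for p in zeroed.parameters():
        p.zero_()  # 100% "sparsity", but still stored and computed as dense tensors

x = torch.randn(8, 64, 16, 16, device=device)
t_dense = mean_forward_time(dense, x)
t_zeroed = mean_forward_time(zeroed, x)
print(f"random weights: {t_dense * 1e3:.2f} ms, all-zero weights: {t_zeroed * 1e3:.2f} ms")
```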

For matmul operations, you may see some speedups when using sparse tensors instead (see the torch.sparse page in the PyTorch 1.13 documentation). For convolutions, you could look into applying structured sparsity: prune entire channels and create a new, smaller dense layer from the channels that remain.
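The channel-pruning idea can be sketched for two directly connected convolutions. The helper shrink_conv_pair below is hypothetical. It keeps the output channels of the first conv with the largest filter L1 norm and slices the matching input channels of the second conv, so both become genuinely smaller dense layers. It assumes a plain conv → activation → conv chain with groups=1. In a ResNet, the BatchNorm after each conv and the residual additions would need extra bookkeeping that this sketch does not handle.

```python
import torch
import torch.nn as nn

def shrink_conv_pair(conv1, conv2, keep_ratio=0.5):
    """Return (new_conv1, new_conv2, keep): smaller dense convs keeping the
    highest-L1-norm output channels of conv1 (hypothetical helper, groups=1)."""
    n_keep = max(1, int(round(conv1.out_channels * keep_ratio)))
    # L1 norm of each output filter; conv weights have shape [out, in, kH, kW]
    scores = conv1.weight.detach().abs().sum(dim=(1, 2, 3))
    keep = torch.argsort(scores, descending=True)[:n_keep].sort().values

    new1 = nn.Conv2d(conv1.in_channels, n_keep, conv1.kernel_size, stride=conv1.stride,
                     padding=conv1.padding, dilation=conv1.dilation,
                     bias=conv1.bias is not None)
    new2 = nn.Conv2d(n_keep, conv2.out_channels, conv2.kernel_size, stride=conv2.stride,
                     padding=conv2.padding, dilation=conv2.dilation,
                     bias=conv2.bias is not None)
    with torch.no_grad():
        new1.weight.copy_(conv1.weight[keep])        # drop whole output filters
        if conv1.bias is not None:
            new1.bias.copy_(conv1.bias[keep])
        new2.weight.copy_(conv2.weight[:, keep])     # drop the matching input channels
        if conv2.bias is not None:
            new2.bias.copy_(conv2.bias)
    return new1, new2, keep

torch.manual_seed(0)
c1, c2 = nn.Conv2d(3, 8, 3, padding=1), nn.Conv2d(8, 4, 3, padding=1)
s1, s2, keep = shrink_conv_pair(c1, c2, keep_ratio=0.5)
x = torch.randn(2, 3, 16, 16)

# Reference: the original pair with the dropped channels zeroed out
channel_mask = torch.zeros(c1.out_channels)
channel_mask[keep] = 1.0
with torch.no_grad():
    y_ref = c2(torch.relu(c1(x) * channel_mask.view(1, -1, 1, 1)))
    y_small = s2(torch.relu(s1(x)))

print(s1.weight.shape, s2.weight.shape)  # torch.Size([4, 3, 3, 3]) torch.Size([4, 4, 3, 3])
print(torch.allclose(y_ref, y_small, atol=1e-5))
```

Unlike a mask, the smaller layers actually perform fewer operations, at the cost of removing whole channels rather than individual weights.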