Hello everyone,
I recently started exploring pruning methods for Deep Neural Networks and stumbled upon some interesting papers suggesting algorithms for unstructured pruning at initialization (e.g. SNIP). Unlike structured pruning, which removes whole neurons (i.e., full rows or columns of a weight matrix, depending on notation), unstructured pruning removes individual weights, i.e., single matrix elements. Several publications state that unstructured pruning achieves higher compression ratios.
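To make the distinction concrete, here is a small sketch of the two variants using `torch.nn.utils.prune` (the layer size and pruning amounts are arbitrary, just for illustration):

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Unstructured: zero out 4 individual weights (lowest L1 magnitude).
layer = nn.Linear(4, 4)
prune.l1_unstructured(layer, name="weight", amount=4)
print((layer.weight == 0).sum().item())  # 4 scattered elements are zero

# Structured: zero out 1 entire output neuron (row of the weight matrix,
# selected by L2 norm along dim 0).
layer2 = nn.Linear(4, 4)
prune.ln_structured(layer2, name="weight", amount=1, n=2, dim=0)
print((layer2.weight == 0).all(dim=1).sum().item())  # exactly 1 full row is zero
```

Both calls replace `weight` with a masked version; the difference is only in the shape of the mask (element-wise vs. whole rows).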
Essentially all research papers I read state that pruning leads to gains in efficiency and reductions in memory requirements, which I cannot reproduce in my own PyTorch experiments. For example, when using ResNet50 with pretrained weights on CIFAR10 and retraining the model for 20 epochs, I get:

M = 27.8995 (SD = 1.1458) seconds per training epoch for the full model

M = 28.7073 (SD = 0.1553) seconds per training epoch for the pruned model (60% of weights in `linear` and `conv` layers pruned)

M = 11.2734 (SD = 0.3740) seconds per evaluation of the test data for the full model (20 iterations)

M = 12.1780 (SD = 0.1735) seconds per evaluation of the test data for the pruned model (20 iterations)
The pruned model is significantly slower for both training (t = 3.04, p < .001) and inference (t = 9.5652, p < .001), which is the exact opposite of what the papers state. I trained on Google Colab, so I assume the results are not confounded by background tasks on my machine. For reference, please see the minimal working example (IPYNB). The same issue is present for ResNet18 as well, though less pronounced, as that model is smaller/faster overall.
From my understanding, this behavior is expected, as pruning in PyTorch is implemented via a binary mask that is applied in each forward pass (the dense `weight_orig` tensor is kept around and still updated in the backward pass). I also tried the `prune.remove` method, but it is not suitable for pruning before training: after removal, the pruned weights are treated as regular weights initialized to 0 and are updated throughout training.
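This masking behavior is easy to observe directly; a small sketch (sizes and pruning amount arbitrary):

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(8, 8)
prune.random_unstructured(layer, name="weight", amount=0.6)

# After pruning, the module keeps the full dense tensor as the parameter
# "weight_orig" plus a binary buffer "weight_mask"; "weight" is recomputed
# as weight_orig * weight_mask on every forward pass.
print(sorted(name for name, _ in layer.named_parameters()))  # ['bias', 'weight_orig']
print(hasattr(layer, "weight_mask"))                         # True

# remove() bakes the mask in permanently: "weight" becomes a plain parameter
# again, so the zeros are just ordinary values the optimizer will update.
prune.remove(layer, name="weight")
print(sorted(name for name, _ in layer.named_parameters()))  # ['bias', 'weight']
```

So in both cases the stored tensors stay dense, which explains why there is no memory saving and why each forward pass even does extra work (the element-wise multiplication with the mask).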
However, from the papers I read, it sounded like the implementation leads to some sort of sparsity structure that improves the efficiency of forward and backward passes. After doing some research, it appears that sparse training is not currently supported in PyTorch, and sparse inference does not appear to be possible for ResNets either (at least in torch == 1.13 with CUDA, because there is no sparse kernel for `Conv2d` layers).
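For what it's worth, sparse matrix multiplication does exist for Linear-style layers via `torch.sparse`; a minimal sketch (sizes and sparsity level arbitrary):

```python
import torch

torch.manual_seed(0)
w = torch.randn(256, 256)
w[torch.rand_like(w) < 0.6] = 0.0  # simulate 60% unstructured pruning
x = torch.randn(256, 32)

dense_out = w @ x
# Sparse (COO) x dense matmul; this is the only part of a ResNet
# (the final fully connected layer) that could use it.
sparse_out = torch.sparse.mm(w.to_sparse(), x)
print(torch.allclose(dense_out, sparse_out, atol=1e-4))  # True
```

But there is no equivalent path for `Conv2d`, and from what I read, at moderate sparsity levels (well below ~90%) unstructured sparse matmul is typically not faster than a dense cuBLAS call on GPU anyway.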
In summary: None of the methods I tried (more than are shown in the minimal example) lead to an increase in efficiency, neither for training nor for inference. Notably, while the papers I read state that pruning is beneficial, none of them report any efficiency measurements to back up that claim, which adds to my confusion. So my question is:
Is there an actual, practical benefit from pruning at initialization or are the assumptions purely theoretical?
Please note that I am still fairly new to PyTorch, having only used Keras before. It is entirely possible that I just messed something up in my code. So all in all, any help is greatly appreciated!
Thanks and happy holidays,
David
Note: As this appears to be a broader issue with pruning, I also asked this question on StackOverflow.