Purpose of PyTorch pruning

I have seen and also tried PyTorch pruning, but have not seen any improvement in inference time or any reduction in model size. Judging by older posts, pruning seems to be an experimental feature. Is it still an experimental tool, or does it have a real impact on inference time?

Hi, I did some research on pruning and might be able to help clear up some doubts.

Pruning by itself only sets some elements of the weight tensor to zero, but those elements are still part of the network. This is the reason why you do not see any difference in model size or inference time.
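You can verify this directly with torch.nn.utils.prune: the pruned entries are zeroed, but the tensor keeps its full shape. A minimal check, assuming a plain Linear layer:

import torch
from torch import nn
from torch.nn.utils import prune

layer = nn.Linear(500, 10)
prune.l1_unstructured(layer, name="weight", amount=0.5)  # zero the 50% smallest-magnitude weights
print(layer.weight.shape)                  # still torch.Size([10, 500]); nothing was removed
print((layer.weight == 0).float().mean())  # ~0.5: the zeros are still stored (plus a mask buffer)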

That said, there are ways to exploit a pruned model to gain some advantage, depending on the pruning strategy that was used.
If you use an unstructured pruning approach (which leads to tensors with zeroes peppered throughout), you can rely on ad hoc hardware or software (e.g. torch.sparse — PyTorch 1.11.0 documentation) that is optimized to work with sparse tensors; see the sketch after this paragraph.
If you use a structured approach (removing entire rows or columns of the weight matrices), you may be able to remove the pruned neurons, leading to smaller and faster models without the need for particular hardware.
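As a minimal sketch of the unstructured route (assuming a simple Linear layer; whether the sparse path is actually faster depends on the sparsity level and the hardware):

import torch
from torch import nn
from torch.nn.utils import prune

layer = nn.Linear(500, 10)
prune.l1_unstructured(layer, name="weight", amount=0.9)  # zero out 90% of the weights
prune.remove(layer, "weight")            # bake the mask into layer.weight permanently
sparse_w = layer.weight.to_sparse()      # COO tensor storing only the non-zero entries
x = torch.randn(500, 1)
out = torch.sparse.mm(sparse_w, x)       # sparse-dense matmul instead of a dense one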

For the second case, we have actually built a PyTorch-compatible library, simplify, which does exactly that. You are welcome to try it out.

If you have any questions feel free to ask.


Thanks for your answer, I really appreciate the work you have put into this. I will definitely check out simplify. I am using unstructured pruning, hence a lot of my weights are set to zero. Does this library support transformers in general?

Our library is currently tested on the standard torchvision classification models, up to torchvision v0.10 (I should probably specify that in the docs, since more models have been added since then).
If you'd like to test it on your specific case and open an issue for any problems you run into, we'll be happy to assist and possibly improve the library.

Just note that unstructured pruning may not lead to any change, since it does not guarantee that entire neurons are pruned. In PyTorch, for a standard Conv2d or Linear, neurons correspond to dimension 0 of the weight tensor, so you might want to check whether some of them are completely zeroed out.

Here's a naive way to check this:

import torch
from torch.nn import Linear

layer = Linear(500, 10)
idx = torch.tensor([0, 2, 5])
with torch.no_grad():
    layer.weight[idx] = layer.weight[idx] * 0  # zero out neurons 0, 2 and 5
# indices of fully zeroed-out neurons; use dim=(1, 2, 3) for Conv2d weights
print(torch.where(layer.weight.abs().sum(dim=1) == 0)[0])

Hi! Can the simplify library handle dependencies among layers, such as skip connections and concatenations? Thanks!

Hello, sorry for the late reply.
Yes, we do handle residual connections (the sum operation) as they are implemented in the ResNet models contained in torchvision.
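To see why the sum creates a dependency, here is a toy block (hypothetical, not the library's code): the output channels of conv2 have to match the channels on the skip path, so pruned channels can only be physically removed if both branches are simplified consistently.

import torch
from torch import nn

class ToyResidualBlock(nn.Module):
    def __init__(self, channels=16):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        identity = x
        out = torch.relu(self.conv1(x))
        out = self.conv2(out)
        # the sum requires conv2's output channels to match the skip path,
        # so removing channels from conv2 alone would break this addition
        return torch.relu(out + identity)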
