How to improve the inference time of a pruned model using torch.nn.utils.prune?

Hi,

I have used torch.nn.utils.prune to prune a pretrained model with the code below.

    import torch
    import torch.nn.utils.prune as prune

    for name, module in model.named_modules():
        # prune 30% of connections in all 2D-conv layers
        if isinstance(module, torch.nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=0.3)
            prune.remove(module, 'weight')
        # prune 40% of connections in all linear layers
        elif isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=0.4)
            prune.remove(module, 'weight')

I have saved the pruned model. The pruned model's file size is halved and its accuracy took a hit, but its inference time did not improve. Please let me know how to improve the inference time of the pruned model.
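For reference, this is roughly how I confirmed the weights are sparse and timed the model (a minimal sketch; the input shape is a placeholder for my real data):

    import time
    import torch

    x = torch.randn(1, 3, 224, 224)  # placeholder input shape

    # Global sparsity over the pruned layer types, to confirm pruning took effect.
    pruned = [m for m in model.modules()
              if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear))]
    total = sum(m.weight.nelement() for m in pruned)
    zeros = sum((m.weight == 0).sum().item() for m in pruned)
    print(f"global sparsity: {zeros / total:.2%}")

    # Average latency over 100 forward passes on the dense, zero-filled weights.
    model.eval()
    with torch.no_grad():
        start = time.time()
        for _ in range(100):
            model(x)
    print(f"avg latency: {(time.time() - start) / 100 * 1000:.1f} ms")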

Thanks in advance.


@DaVenBi can you tell me how you saved the model?

Since the sparsity is applied in an unstructured manner, I don't think there is an out-of-the-box solution that speeds up inference just because the matrices are sparser. I suggest that you try pruning in a structured way and see how that changes inference speed.
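For example, something along these lines would zero whole convolution filters instead of scattered individual weights (an untested sketch; the 30% amount and the dim=0 choice are just illustrative):

    import torch
    import torch.nn.utils.prune as prune

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            # Zero 30% of output channels (dim=0), ranked by their L2 norm,
            # so entire filters are pruned rather than individual weights.
            prune.ln_structured(module, name='weight', amount=0.3, n=2, dim=0)
            prune.remove(module, 'weight')

Note that this still only zeroes the channels in place; to see an actual latency win on dense hardware you would also need to physically drop the zeroed filters, e.g. by rebuilding the layers with fewer channels.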


As @ypsoh mentioned, the torch.sparse module is still in beta. So you won’t get any speed up for a model pruned in an unstructured way unless you make use of sparse matrix operations. Hopefully torch.sparse will be part of a stable release soon.
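To illustrate what using sparse operations would look like today, here is a small sketch with a weight pruned to roughly 90% sparsity (whether this is actually faster than the dense matmul depends on the sparsity level, the sizes, and the backend):

    import torch

    # A dense weight with ~90% of entries zeroed, as unstructured pruning leaves it.
    weight = torch.randn(1024, 1024)
    weight[torch.rand_like(weight) < 0.9] = 0.0

    sparse_weight = weight.to_sparse()  # convert to a COO sparse tensor
    x = torch.randn(1024, 32)

    dense_out = weight @ x
    sparse_out = torch.sparse.mm(sparse_weight, x)
    print(torch.allclose(dense_out, sparse_out, atol=1e-4))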
