Weight Pruning on BERT

Hello, I used torch.nn.utils.prune to apply weight pruning to BERT. It produces sparse weight matrices as expected. However, the model's size doesn't decrease and inference isn't any faster. How can I speed up inference after pruning and actually shrink the model? Thank you!

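import torch.nn as nn
import torch.nn.utils.prune as prune
from transformers import BertForQuestionAnswering

BERT_QA = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
for child in BERT_QA.modules():
    if isinstance(child, nn.Linear):
        # zero out the 20% of entries with the smallest absolute value
        prune.l1_unstructured(child, name='weight', amount=0.2)
        prune.l1_unstructured(child, name='bias', amount=0.2)
        # make the pruning permanent: drop weight_orig/weight_mask, keep the zeroed tensor
        prune.remove(child, 'weight')
        prune.remove(child, 'bias')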
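To see why neither size nor speed changes, here is a minimal sketch on a single small nn.Linear, used as a stand-in for one of BERT's layers so that it runs without downloading weights (the layer sizes are arbitrary assumptions). It applies the same pruning calls to the weight and then checks that the tensor keeps its shape, element count and byte size; only some of its entries become zero.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Small stand-in for one of BERT's Linear layers (sizes are arbitrary).
layer = nn.Linear(64, 32)
shape_before = tuple(layer.weight.shape)
bytes_before = layer.weight.numel() * layer.weight.element_size()

# Same calls as in the question: zero the 20% smallest-magnitude weights,
# then make the pruning permanent.
prune.l1_unstructured(layer, name='weight', amount=0.2)
prune.remove(layer, 'weight')

w = layer.weight
bytes_after = w.numel() * w.element_size()
zero_fraction = (w == 0).float().mean().item()

print(shape_before, tuple(w.shape))  # same shape
print(bytes_before, bytes_after)     # same number of bytes
print(f"zero fraction: {zero_fraction:.2f}")
print(w.layout)                      # still torch.strided, i.e. a dense tensor
```

The zeros are stored exactly like any other value, so saving the model or running a forward pass costs the same as before.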

At the moment, the point of PyTorch pruning is not necessarily to guarantee inference-time speedups or memory savings; it is more of an experimental feature meant to enable pruning research.
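One way to see that torch.nn.utils.prune targets experimentation rather than compression: until prune.remove is called, the pruned module keeps the original dense tensor as a parameter named weight_orig, adds a dense weight_mask buffer of the same shape, and recomputes weight from the two in a forward pre-hook. The sketch below uses a toy layer (sizes are arbitrary) and a helper, stored_bytes, written only for this example, to count the bytes in the module's state_dict at each stage.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def stored_bytes(module: nn.Module) -> int:
    """Hypothetical helper: bytes held in a module's parameters and buffers."""
    return sum(t.numel() * t.element_size() for t in module.state_dict().values())

layer = nn.Linear(64, 32)
before = stored_bytes(layer)

prune.l1_unstructured(layer, name='weight', amount=0.2)
during = stored_bytes(layer)  # weight_orig + weight_mask + bias

print(sorted(name for name, _ in layer.named_parameters()))  # ['bias', 'weight_orig']
print(sorted(name for name, _ in layer.named_buffers()))     # ['weight_mask']
print(before, during)  # the pruned module stores more bytes, not fewer

prune.remove(layer, 'weight')
after = stored_bytes(layer)
print(after == before)  # True: back to one dense weight, now containing zeros
```

While the mask is attached, memory use goes up; after prune.remove it is back to the unpruned size.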

In any case, answers to questions similar to yours were given here and here.

TL;DR: You can save space by calling .to_sparse(), which converts your pruned tensor (still dense, just containing zeros) into the sparse coordinate (COO) representation. This only pays off when the tensor is sparse enough, because COO also stores the indices of every nonzero entry. You cannot expect any inference speedups unless you use a custom sparse matrix algebra library to power your computation, and torch.sparse is still a work in progress. Otherwise, you'll just be doing the same number of operations as before pruning, only now with a bunch of entries equal to zero in your tensors.
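A sketch of what .to_sparse() buys you, assuming float32 weights and PyTorch's default COO layout with int64 indices. Each stored nonzero costs its 4-byte value plus one 8-byte index per dimension, so for a 2-D weight the COO copy is only smaller than the dense one above roughly 80% sparsity; at the 20% used in the question it is larger. The matrix size and the helpers dense_bytes and coo_bytes are assumptions made for this example, and the magnitude threshold mimics what prune.l1_unstructured does to a single tensor. The final check confirms torch.sparse.mm gives the same product as the dense matmul.

```python
import torch

def dense_bytes(t: torch.Tensor) -> int:
    """Hypothetical helper: bytes used by a tensor's elements."""
    return t.numel() * t.element_size()

def coo_bytes(s: torch.Tensor) -> int:
    """Hypothetical helper: bytes for a COO tensor's indices plus values."""
    s = s.coalesce()
    return dense_bytes(s.indices()) + dense_bytes(s.values())

torch.manual_seed(0)
w = torch.randn(1024, 1024)

sizes = {}
for sparsity in (0.2, 0.9):
    # Zero the given fraction of smallest-magnitude entries.
    k = int(sparsity * w.numel())
    threshold = w.abs().flatten().kthvalue(k).values
    pruned = torch.where(w.abs() > threshold, w, torch.zeros_like(w))
    sparse = pruned.to_sparse()  # COO layout
    sizes[sparsity] = (dense_bytes(pruned), coo_bytes(sparse))
    print(f"sparsity {sparsity}: dense {sizes[sparsity][0]} B, COO {sizes[sparsity][1]} B")

# Uses the 90%-sparse matrix from the last iteration: same result either way,
# only a sparse kernel can actually skip the zeros.
x = torch.randn(1024, 8)
assert torch.allclose(torch.sparse.mm(sparse, x), pruned @ x, atol=1e-4)
```

Whether torch.sparse.mm is actually faster than the dense product depends on sparsity, shape and hardware, so it is worth timing both on your own setup.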


Hello Michela,
sorry to reopen this thread.
Is there any update on using pruning to reduce inference time? If not, which library would you recommend for performing these computations efficiently?

Thanks!

Best, Andreas
