How to prune weights less than a threshold in PyTorch?

How can I prune the weights of a CNN (convolutional neural network) model that are less than a threshold value (for example, prune all weights that are <= 1)? I want to prune only the less significant weights so that accuracy does not degrade.

How can we achieve that for a weight file saved in .pth format in PyTorch?

You can implement it as an extension of torch.nn.utils.prune. The PyTorch pruning tutorial explains how to create your own custom pruning function by subclassing its base pruning class.
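A minimal sketch of such an extension, following the subclassing pattern used in the tutorial (subclass prune.BasePruningMethod, set PRUNING_TYPE, implement compute_mask, then call the classmethod apply). The names ThresholdPruning and threshold_pruning are hypothetical, chosen for this example. compute_mask multiplies the threshold test into default_mask so that entries pruned earlier stay pruned.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


class ThresholdPruning(prune.BasePruningMethod):
    """Hypothetical custom method: prune every entry <= threshold."""

    PRUNING_TYPE = "unstructured"  # decide entry by entry, not per channel

    def __init__(self, threshold):
        self.threshold = threshold

    def compute_mask(self, t, default_mask):
        # 1 keeps an entry, 0 prunes it; combine with any existing mask.
        return default_mask * (t > self.threshold).to(t.dtype)


def threshold_pruning(module, name, threshold):
    # apply() forwards extra keyword arguments to ThresholdPruning(...),
    # registers `<name>_orig` and `<name>_mask`, and adds a forward pre-hook.
    ThresholdPruning.apply(module, name, threshold=threshold)
    return module


# Usage on a tiny conv layer with known weights.
conv = nn.Conv2d(1, 1, kernel_size=2, bias=False)
with torch.no_grad():
    conv.weight.copy_(torch.tensor([[[[0.5, 1.0], [1.5, 2.0]]]]))

threshold_pruning(conv, "weight", threshold=1.0)
print(conv.weight_mask)  # entries 0.5 and 1.0 are masked out
print(conv.weight)       # weight_orig * weight_mask
```

If "less significant" is meant by magnitude, so that large negative weights survive, compare t.abs() > self.threshold instead; the mask above follows the literal "<= 1" from the question.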

In practice, inside compute_mask you will need (among other things) to build a mask that keeps every entry of the tensor t that is above your threshold and zeroes the rest: mask = (t > 1.).to(t.dtype)
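To apply that same mask to a model stored as a .pth file, one option is the built-in prune.custom_from_mask, followed by prune.remove to make the pruning permanent before saving again. This is a sketch under several assumptions: the .pth holds a state_dict (not a fully pickled model), the SmallCNN class and prune_weights_below helper are hypothetical stand-ins for your own architecture, and only Conv2d/Linear weights (not biases) are pruned. The demo uses a threshold of 0.1 because freshly initialized weights here are all well below 1.0.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


# Hypothetical architecture: replace with the class your .pth was saved from.
class SmallCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3)
        self.conv2 = nn.Conv2d(4, 8, kernel_size=3)
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = x.mean(dim=(2, 3))  # global average pooling over H and W
        return self.fc(x)


def prune_weights_below(model, threshold):
    """Hypothetical helper: zero every conv/linear weight <= threshold."""
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            w = module.weight.detach()
            # Same mask as above: 1 keeps an entry, 0 prunes it.
            mask = (w > threshold).to(w.dtype)
            prune.custom_from_mask(module, name="weight", mask=mask)
            # Fold the mask into the weight; drops weight_orig/weight_mask.
            prune.remove(module, "weight")
    return model


THRESHOLD = 0.1  # demo value; the question uses 1.0
workdir = tempfile.mkdtemp()
orig_path = os.path.join(workdir, "model.pth")
pruned_path = os.path.join(workdir, "model_pruned.pth")

# Stand-in for an existing checkpoint: save a freshly initialized state_dict.
torch.save(SmallCNN().state_dict(), orig_path)

model = SmallCNN()
model.load_state_dict(torch.load(orig_path, map_location="cpu"))
prune_weights_below(model, THRESHOLD)
torch.save(model.state_dict(), pruned_path)
```

Calling prune.remove is what keeps the saved state_dict free of weight_orig/weight_mask entries, so the pruned file loads into an unmodified model class. Skip it if you want to keep the masks around, for example to fine-tune with the pruned entries held at zero.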