Global-local structured pruning

Hello,

I’m struggling with the implementation of a global-local structured pruning method, which I read about in this paper. In the global step, entire filters are removed by comparing the sum of absolute weights of each filter (normalized by the filter's kernel size) across all filters in the model. In the local step, a minimum-filter (MF) rule ensures that every conv layer keeps at least a minimum number of filters; a single MF value of 5 is applied to all layers of the model.
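To make sure I understand the criterion, here is a toy sketch of it on a single layer (the names and the per-kernel normalization are my reading of the paper, not taken from it):

```python
import torch

torch.manual_seed(0)
w = torch.randn(8, 3, 3, 3)  # one conv layer: 8 filters with 3x3 kernels

# Per-filter score: sum of absolute weights, normalized by kernel size
kh, kw = w.shape[2], w.shape[3]
scores = w.abs().sum(dim=(1, 2, 3)) / (kh * kw)

# Global step: pool scores across layers (just this one here) and use the
# percentile matching the pruning ratio as a single threshold.
threshold = torch.quantile(scores, 0.5)
keep = scores > threshold

# Local step: the MF rule forces at least MF filters to survive per layer.
MF = 5
if keep.sum() < MF:
    keep = torch.zeros_like(keep)
    keep[scores.topk(MF).indices] = True

print(int(keep.sum()))  # never fewer than MF filters
```

With a 50% ratio on 8 filters, only 4 would pass the threshold, so the MF rule kicks in and keeps the top 5. That matches my reading; please correct me if the criterion differs.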

I tested my implementation with a pruning ratio of 50% on a VGG19 model pre-trained on CIFAR-10, evaluated on the test set. The original model reaches 89% accuracy, but the pruned model drops to 10%, whereas the paper reports 68% accuracy for the pruned model (Figure 2(a)).

Here is my implementation:

import torch
import numpy as np

# original_model: the VGG19 pre-trained on CIFAR-10, loaded beforehand

# Fraction of filters to prune globally
pruning_percentage = 0.5

# Collect all the filter norms in the convolutional layers
filter_norms = []
for name, module in original_model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        # Per-filter score: sum of absolute weights, normalized by kernel size
        kh, kw = module.weight.shape[2], module.weight.shape[3]
        f_norm = module.weight.data.abs().sum(dim=(1, 2, 3)) / (kh * kw)
        filter_norms.extend(f_norm.cpu().numpy())

# Determine the threshold for pruning
threshold = np.percentile(np.array(filter_norms), pruning_percentage * 100)

# Prune the filters in the convolutional layers
for name, module in original_model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        # Recompute the per-filter scores for this layer
        kh, kw = module.weight.shape[2], module.weight.shape[3]
        f_norm = module.weight.data.abs().sum(dim=(1, 2, 3)) / (kh * kw)

        # Compute a mask based on the threshold
        mask = (f_norm > threshold).float().view(-1, 1, 1, 1)

        # Local rule: keep at least MF = 5 filters in this layer
        if mask.sum() < 5:
            _, indices = torch.topk(f_norm, 5)
            new_mask = torch.zeros_like(f_norm)
            new_mask[indices] = 1
            new_mask = new_mask.view(-1, 1, 1, 1)
            mask = new_mask

        # Apply the mask to the weights
        module.weight.data.mul_(mask)

# Save the state dict of the pruned model
torch.save(original_model.state_dict(), './vgg19/global_pruned_cifar10_vgg19.pth')
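For debugging I also wrote this small helper (my own sketch, not from the paper) that counts how many filters in each Conv2d layer still have non-zero weights after masking, so I can check whether the MF rule actually held everywhere:

```python
import torch

def count_surviving_filters(model):
    """Return, per Conv2d layer, how many filters are not entirely zero.

    After multiplicative masking, a pruned filter has all-zero weights,
    so counting filters with a non-zero absolute sum shows how many
    survived in each layer.
    """
    counts = {}
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            per_filter = module.weight.data.abs().sum(dim=(1, 2, 3))
            counts[name] = int((per_filter > 0).sum())
    return counts
```

In my runs every layer reports at least 5 surviving filters, so the local rule seems to hold, yet the accuracy still collapses to chance level.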

Any help would be much appreciated!