When should prune.remove() be called when using the pruning toolbox to prune a CNN?

Hi,

I am using the pruning toolbox to prune a ResNet-18. I see that there is a function called prune.remove(), which is meant to make the pruning permanent. When should this function ideally be called?

If I train a model, perform L1 structured pruning, make the pruning permanent by calling prune.remove(), and then fine-tune, the model performs differently than if I train the model, perform L1 structured pruning, fine-tune, and only then make the pruning permanent right before deployment. I experimented with a pruning amount of 0.7. With the former method, performance is almost equal to the baseline unpruned model, while with the latter it drops significantly, which I think is what should happen, since we lost 70% of the filters in each layer.

Can somebody please shed some light on this? My pruning code is below.

    import torch
    import torch.nn.utils.prune as prune

    prune_strategy = prune.ln_structured
    addn_params = {'n':1, 'dim':1}
    path_to_checkpoint = 'results/{}/checkpoint_200.pth'.format(checkpoint_folder)
    checkpoint = torch.load(path_to_checkpoint)
    #model.load_state_dict(checkpoint['model_state_dict'])
    model.load_state_dict(checkpoint)
    print("Model loaded successfully : ", path_to_checkpoint)
    num_params_before, num_params_after = 0, 0
    for name, module in model.named_modules():
        # print(name, module)
        # not pruning fc layers
        if isinstance(module, torch.nn.Conv2d):
            prune_strategy(module, name='weight', amount=args.prune_amount, **addn_params)
            prune.remove(module, name='weight')
            # If prune.remove() runs here, while the layers are being pruned, the performance of the model
            # is as good as the baseline. If I comment it out, the performance drops.

Thank you,
Uzair

Hi there,
I am working on pruning too, and I found this walkthrough helpful.
Check it out: The Lottery Ticket Hypothesis and pruning in PyTorch (YouTube)
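
On the original question, here is a minimal sketch of what I understand prune.remove() to do, in case it helps. While the pruning is still attached, the layer stores weight_orig plus a weight_mask buffer and recomputes weight = weight_orig * weight_mask on every forward pass, so the pruned entries stay at zero during fine-tuning. After prune.remove(), weight is an ordinary parameter again with the zeros baked in, and nothing stops the optimizer from moving them away from zero. That may be why fine-tuning after prune.remove() gets back close to the baseline. The layer sizes, learning rate, dummy loss, and the helper names state_names and zero_fraction_after_step are made up for illustration and are not from the code above.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


def state_names(make_permanent: bool):
    """Return the parameter and buffer names of a pruned conv layer."""
    conv = nn.Conv2d(4, 8, kernel_size=3)
    prune.ln_structured(conv, name="weight", amount=0.5, n=1, dim=1)
    if make_permanent:
        prune.remove(conv, "weight")
    params = {n for n, _ in conv.named_parameters()}
    buffers = {n for n, _ in conv.named_buffers()}
    return params, buffers


def zero_fraction_after_step(make_permanent: bool) -> float:
    """Prune, optionally call prune.remove(), take one SGD step,
    and report the fraction of weights that are still exactly zero."""
    torch.manual_seed(0)
    conv = nn.Conv2d(4, 8, kernel_size=3)
    prune.ln_structured(conv, name="weight", amount=0.5, n=1, dim=1)
    if make_permanent:
        prune.remove(conv, "weight")

    opt = torch.optim.SGD(conv.parameters(), lr=0.1)
    x = torch.randn(2, 4, 8, 8)
    conv(x).pow(2).sum().backward()  # dummy loss standing in for fine-tuning
    opt.step()

    conv(x)  # with the mask attached, this forward pass refreshes conv.weight
    return (conv.weight == 0).float().mean().item()


if __name__ == "__main__":
    print("mask attached:", state_names(False))
    print("after remove :", state_names(True))
    print("zeros kept, mask attached:", zero_fraction_after_step(False))
    print("zeros kept, after remove :", zero_fraction_after_step(True))
```

If the zero fraction drops in the second case on your model as well, then keeping the mask attached during fine-tuning and calling prune.remove() only at the end is the order that actually preserves the sparsity.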