Pruning with "torch.nn.utils.prune" module - no pruning seen

I have been trying to implement pruning using `torch.nn.utils.prune` according to the tutorial here. For experimentation, I am using a LeNet-300-100 dense neural network with the MNIST dataset, which can be accessed here.

import torch
import torch.nn.utils.prune as prune

# Pruning multiple parameters-
# Prune multiple parameters/layers in a given model-
for name, module in best_model.named_modules():
    # prune 20% of weights/connections in all hidden layers-
    if isinstance(module, torch.nn.Linear) and name != 'output':
        prune.l1_unstructured(module = module, name = 'weight', amount = 0.2)
    
    # prune 10% of weights/connections for output layer-
    elif isinstance(module, torch.nn.Linear) and name == 'output':
        prune.l1_unstructured(module = module, name = 'weight', amount = 0.1)


# Sanity check: verify that all of the defined pruning exists as masks-
print(dict(best_model.named_buffers()).keys())
# dict_keys(['fc1.weight_mask', 'fc2.weight_mask', 'output.weight_mask'])
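To see what these pruning calls actually do to each layer, here is a self-contained sketch. The `LeNet300100` class is a hypothetical stand-in whose layer names (`fc1`, `fc2`, `output`) are inferred from the buffer names printed above; the real model may differ in details such as activations. It applies the same 20%/20%/10% schedule and inspects the resulting parameters, buffers and masks.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


class LeNet300100(nn.Module):
    """Hypothetical LeNet-300-100; layer names follow the buffer names fc1, fc2, output."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 300)
        self.fc2 = nn.Linear(300, 100)
        self.output = nn.Linear(100, 10)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.output(x)


model = LeNet300100()

# Same schedule as above: 20% of hidden-layer weights, 10% of output-layer weights.
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        amount = 0.1 if name == "output" else 0.2
        prune.l1_unstructured(module, name="weight", amount=amount)

# The dense tensor is now registered as 'weight_orig'; 'weight' is a plain attribute
# equal to weight_orig * weight_mask, recomputed by a forward pre-hook.
print(sorted(dict(model.named_buffers()).keys()))
print([n for n, _ in model.named_parameters()])

for name in ("fc1", "fc2", "output"):
    layer = getattr(model, name)
    kept = int(layer.weight_mask.sum().item())
    print(f"{name}: mask keeps {kept} of {layer.weight_mask.numel()} weights")
```

The masks keep 188160, 24000 and 900 weights respectively, yet every stored tensor keeps its original shape.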

I then train the model with the following code snippet:

# Training loop-
for epoch in range(num_epochs):
    running_loss = 0.0
    running_corrects = 0.0
    
    if loc_patience >= patience:
        print("\n'EarlyStopping' called!\n")
        break

    running_loss, running_corrects = train_model(best_model, train_loader)
  
    epoch_loss = running_loss / len(train_dataset)
    epoch_acc = running_corrects.double() / len(train_dataset)
    # epoch_acc = 100 * running_corrects / len(trainset)
    # print(f"\nepoch: {epoch + 1} training loss = {epoch_loss:.4f}, training accuracy = {epoch_acc * 100:.2f}%\n")

    running_loss_val, correct, total = test_model(best_model, test_loader)

    epoch_val_loss = running_loss_val / len(test_dataset)
    val_acc = 100 * (correct / total)
    # print(f"\nepoch: {epoch + 1} training loss = {epoch_loss:.4f}, training accuracy = {epoch_acc * 100:.2f}%, val_loss = {epoch_val_loss:.4f} & val_accuracy = {val_acc:.2f}%\n")

    print(f"\nepoch: {epoch + 1} training loss = {epoch_loss:.4f}, training accuracy = {epoch_acc * 100:.2f}%, val_loss = {epoch_val_loss:.4f} & val_accuracy = {val_acc:.2f}%")

    curr_params = count_params(best_model)
    print(f"Number of parameters = {curr_params}\n")
    
    percentage_pruned = ((orig_params - curr_params.numpy()) / orig_params * 100).numpy()
    
    # Code for manual Early Stopping:
    # if np.abs(epoch_val_loss < best_val_loss) >= minimum_delta:
    if (epoch_val_loss < best_val_loss) and np.abs(epoch_val_loss - best_val_loss) >= minimum_delta:
        # print(f"epoch_val_loss = {epoch_val_loss:.4f}, best_val_loss = {best_val_loss:.4f}")
        
        # update 'best_val_loss' variable to lowest loss encountered so far-
        best_val_loss = epoch_val_loss
        
        # reset 'loc_patience' variable-
        loc_patience = 0
        
        print(f"\nSaving model with lowest val_loss = {epoch_val_loss:.4f}")
        
        # Save trained model with validation accuracy-
        # torch.save(model.state_dict, f"LeNet-300-100_Trained_{val_acc}.pth")
        torch.save(best_model.state_dict(), f"LeNet-300-100_{percentage_pruned:.2f}.pth")
        
    else:  # there is no improvement in monitored metric 'val_loss'
        loc_patience += 1  # number of epochs without any improvement

The output is:

epoch: 1 training loss = 0.0266, training accuracy = 99.11%, val_loss = 0.0980 & val_accuracy = 97.94%
Number of parameters = 266610

Saving model with lowest val_loss = 0.0980

epoch: 2 training loss = 0.0266, training accuracy = 99.11%, val_loss = 0.0980 & val_accuracy = 97.94%
Number of parameters = 266610

epoch: 3 training loss = 0.0266, training accuracy = 99.11%, val_loss = 0.0980 & val_accuracy = 97.94%
Number of parameters = 266610

epoch: 4 training loss = 0.0266, training accuracy = 99.11%, val_loss = 0.0980 & val_accuracy = 97.94%
Number of parameters = 266610

epoch: 5 training loss = 0.0266, training accuracy = 99.11%, val_loss = 0.0980 & val_accuracy = 97.94%
Number of parameters = 266610

epoch: 6 training loss = 0.0266, training accuracy = 99.11%, val_loss = 0.0980 & val_accuracy = 97.94%
Number of parameters = 266610

'EarlyStopping' called!

This shows that:

  1. The model is “frozen”: it doesn’t learn anything, since the training and validation loss and accuracy stay exactly the same from epoch to epoch.

  2. The number of parameters stays the same, so the defined pruning doesn’t seem to have any effect.
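For reference, 266610 is exactly the parameter count of the unpruned LeNet-300-100 (assuming 784 inputs, hidden layers of 300 and 100 units, 10 outputs, all with biases). That is consistent with `count_params` measuring full tensor sizes rather than surviving weights. The arithmetic, including how many weights the 20%/20%/10% schedule should zero out:

```python
# LeNet-300-100 layer shapes (assumption: 28x28 MNIST images flattened to 784, 10 classes).
layers = {"fc1": (784, 300), "fc2": (300, 100), "output": (100, 10)}
prune_amounts = {"fc1": 0.2, "fc2": 0.2, "output": 0.1}  # schedule from the question

# Each nn.Linear has n_in * n_out weights plus n_out biases.
dense_total = sum(n_in * n_out + n_out for n_in, n_out in layers.values())

# Only weights are pruned (biases are untouched); round() mirrors how a fractional amount
# is turned into a number of entries.
pruned_weights = sum(
    round(prune_amounts[name] * n_in * n_out) for name, (n_in, n_out) in layers.items()
)

print(dense_total)                   # 266610
print(pruned_weights)                # 53140
print(dense_total - pruned_weights)  # 213470
```

So roughly 19.93% of all parameters should be zeroed, but a size-based count will never show it.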

How can I fix them?

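When you use the torch.nn.utils.prune module in PyTorch to prune a neural network, it does not actually remove the parameters from the model. Instead, it keeps the original dense tensor (as weight_orig), stores a binary mask of the same size (weight_mask) as a buffer, and recomputes weight = weight_orig * weight_mask before every forward pass, during training as well as inference. The pruned connections are therefore zeroed out, but every tensor keeps its full size, so a parameter count based on tensor sizes stays the same. The zeros may make the model easier to compress later (for example with sparse storage), but on their own they provide no real reduction in the number of parameters, memory or computation.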
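This behaviour can be checked on a single layer. The sketch below uses an `nn.Linear(300, 100)` (the size of `fc2`, chosen only as an example) and also calls `prune.remove`, which makes the pruning permanent by baking the zeros into a plain `weight` parameter and dropping the mask; even then, the tensor stays dense.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(300, 100)
prune.l1_unstructured(layer, name="weight", amount=0.2)

# After pruning: the full dense tensor is kept and a same-sized 0/1 mask is added,
# so storage actually grows.
print(layer.weight_orig.shape, layer.weight_mask.shape)

# prune.remove() makes the pruning permanent: 'weight' becomes an ordinary parameter
# with the zeros baked in, and the mask buffer and forward pre-hook are removed.
prune.remove(layer, "weight")
print(sorted(name for name, _ in layer.named_parameters()))

# The shape is unchanged; 20% of the 30000 weights are simply zero.
print(layer.weight.shape, int((layer.weight == 0).sum()))
```

Physically smaller layers therefore require a structural tool that rebuilds the tensors, rather than `torch.nn.utils.prune` alone.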

If you are looking for a tool that can truly remove the pruned parameters from your model, the GitHub project VainF/Torch-Pruning ("[CVPR-2023] Towards Any Structural Pruning") may be useful.
