I have been trying to implement pruning using `torch.nn.utils.prune` according to the tutorial here. For experimentation purposes, I am using a LeNet-300-100 dense neural network on the MNIST dataset, which can be accessed here.
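For reference, a minimal sketch of the architecture (the layer names `fc1`/`fc2`/`output` match the mask keys printed below; the exact activations in the real model may differ):

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

class LeNet300(nn.Module):
    # 784-300-100-10 fully-connected network: 266,610 parameters in total
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 300)
        self.fc2 = nn.Linear(300, 100)
        self.output = nn.Linear(100, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x.view(-1, 28 * 28)))
        x = torch.relu(self.fc2(x))
        return self.output(x)

best_model = LeNet300()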
# Prune multiple parameters/layers in a given model-
for name, module in best_model.named_modules():
    # prune 20% of weights/connections for all hidden layers-
    if isinstance(module, torch.nn.Linear) and name != 'output':
        prune.l1_unstructured(module = module, name = 'weight', amount = 0.2)
    # prune 10% of weights/connections for the output layer-
    elif isinstance(module, torch.nn.Linear) and name == 'output':
        prune.l1_unstructured(module = module, name = 'weight', amount = 0.1)

# Sanity check: verify that all of the defined pruning exists as masks-
print(dict(best_model.named_buffers()).keys())
# dict_keys(['fc1.weight_mask', 'fc2.weight_mask', 'output.weight_mask'])
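As I understand it, `l1_unstructured` does not delete any weights; it reparametrizes the module so that `weight` is recomputed from `weight_orig` and `weight_mask` on every forward pass. A quick check (sketch):

fc1 = best_model.fc1

# 'weight' is no longer a Parameter; 'weight_orig' holds the original values
print('weight' in dict(fc1.named_parameters()))       # False
print('weight_orig' in dict(fc1.named_parameters()))  # True

# the 'weight' attribute is the masked product of the two
print(torch.equal(fc1.weight, fc1.weight_orig * fc1.weight_mask))  # True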
I then train the model with the following code snippet:
# Training loop-
for epoch in range(num_epochs):
    running_loss = 0.0
    running_corrects = 0.0

    if loc_patience >= patience:
        print("\n'EarlyStopping' called!\n")
        break

    running_loss, running_corrects = train_model(best_model, train_loader)
    epoch_loss = running_loss / len(train_dataset)
    epoch_acc = running_corrects.double() / len(train_dataset)

    running_loss_val, correct, total = test_model(best_model, test_loader)
    epoch_val_loss = running_loss_val / len(test_dataset)
    val_acc = 100 * (correct / total)

    print(f"\nepoch: {epoch + 1} training loss = {epoch_loss:.4f}, training accuracy = {epoch_acc * 100:.2f}%, val_loss = {epoch_val_loss:.4f} & val_accuracy = {val_acc:.2f}%")

    curr_params = count_params(best_model)
    print(f"Number of parameters = {curr_params}\n")
    percentage_pruned = (orig_params - curr_params) / orig_params * 100

    # Manual early stopping:
    if (epoch_val_loss < best_val_loss) and np.abs(epoch_val_loss - best_val_loss) >= minimum_delta:
        # update 'best_val_loss' to the lowest loss encountered so far-
        best_val_loss = epoch_val_loss

        # reset 'loc_patience'-
        loc_patience = 0

        print(f"\nSaving model with lowest val_loss = {epoch_val_loss:.4f}")

        # Save trained model-
        torch.save(best_model.state_dict(), f"LeNet-300-100_{percentage_pruned:.2f}.pth")
    else:  # no improvement in the monitored metric 'val_loss'
        loc_patience += 1  # number of epochs without any improvement
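Here, `count_params` is a simple helper along these lines (sketch; note that it counts every element of every parameter tensor, so weights zeroed out by a mask are still counted):

def count_params(model):
    # total number of elements across all parameter tensors;
    # masked-out (zeroed) weights are still included in this count
    return sum(p.numel() for p in model.parameters())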
The output is:
epoch: 1 training loss = 0.0266, training accuracy = 99.11%, val_loss = 0.0980 & val_accuracy = 97.94%
Number of parameters = 266610

Saving model with lowest val_loss = 0.0980

epoch: 2 training loss = 0.0266, training accuracy = 99.11%, val_loss = 0.0980 & val_accuracy = 97.94%
Number of parameters = 266610

epoch: 3 training loss = 0.0266, training accuracy = 99.11%, val_loss = 0.0980 & val_accuracy = 97.94%
Number of parameters = 266610

epoch: 4 training loss = 0.0266, training accuracy = 99.11%, val_loss = 0.0980 & val_accuracy = 97.94%
Number of parameters = 266610

epoch: 5 training loss = 0.0266, training accuracy = 99.11%, val_loss = 0.0980 & val_accuracy = 97.94%
Number of parameters = 266610

epoch: 6 training loss = 0.0266, training accuracy = 99.11%, val_loss = 0.0980 & val_accuracy = 97.94%
Number of parameters = 266610

'EarlyStopping' called!
This shows that:

- The model is "frozen": it doesn't learn anything, since the val_accuracy and val_loss values stay the same across epochs.
- The number of parameters stays the same, so the defined pruning doesn't seem to be having any effect (a sparsity check is sketched below).
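To verify the second point, I would expect the global weight sparsity to be measurable by counting zero entries, as the tutorial does (sketch; immediately after pruning this should report roughly 20%, given the 20%/20%/10% split above):

zeros, total = 0.0, 0.0
for name, module in best_model.named_modules():
    if isinstance(module, torch.nn.Linear):
        # count weight entries zeroed out by the pruning masks
        zeros += float(torch.sum(module.weight == 0))
        total += float(module.weight.nelement())

print(f"Global weight sparsity: {100.0 * zeros / total:.2f}%")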
How can I fix these two issues?