Counting pruned parameters

I am using "torch.nn.utils.prune" to prune a LeNet CNN, as shown in the PyTorch Pruning Tutorial. The code is as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNet().to(device=device)

# Define the parameters to prune
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

# Apply global unstructured pruning
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
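
As described in the tutorial, this call reparametrizes each pruned module: 'weight' is no longer a parameter; instead the module holds the original dense tensor as the parameter 'weight_orig' plus a binary buffer 'weight_mask', and the masked 'weight' is recomputed as their product. A quick check on one module (a minimal sketch; the printed names are what I would expect):

# 'weight' has been removed from the parameter list; 'weight_orig' replaces it
print([name for name, _ in model.conv1.named_parameters()])
# expected: ['bias', 'weight_orig']

# the binary pruning mask is stored as a buffer
print([name for name, _ in model.conv1.named_buffers()])
# expected: ['weight_mask']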


def calculate_sparsity(model):
    '''
    Computes layer-wise & global sparsity.

    Refer:
    https://discuss.pytorch.org/t/how-to-count-the-number-of-zero-weights-in-a-pytorch-model/13549
    '''
    num_zeros = 0
    num_wts_global = 0

    for param in model.parameters():
        loc_zeros = torch.sum(param == 0).item()
        num_zeros += loc_zeros
        num_wts_global += param.numel()
        print(f"layer.shape = {param.shape}, # of wts = {param.numel()} & # of zeros = {loc_zeros}")

    # multiply by 100 so the printed value is actually a percentage
    print(f"\nGlobal sparsity = {100. * num_zeros / num_wts_global:.2f}%")

However, executing the function gives the following output:

calculate_sparsity(model)

layer.shape = torch.Size([6]), # of wts = 6 & # of zeros = 0
layer.shape = torch.Size([6, 1, 3, 3]), # of wts = 54 & # of zeros = 0
layer.shape = torch.Size([16]), # of wts = 16 & # of zeros = 0
layer.shape = torch.Size([16, 6, 3, 3]), # of wts = 864 & # of zeros = 0
layer.shape = torch.Size([120]), # of wts = 120 & # of zeros = 0
layer.shape = torch.Size([120, 400]), # of wts = 48000 & # of zeros = 0
layer.shape = torch.Size([84]), # of wts = 84 & # of zeros = 0
layer.shape = torch.Size([84, 120]), # of wts = 10080 & # of zeros = 0
layer.shape = torch.Size([10]), # of wts = 10 & # of zeros = 0
layer.shape = torch.Size([10, 84]), # of wts = 840 & # of zeros = 0

Global sparsity = 0.00%

This clearly shows that 'param' does not see the pruned weights, which are stored in each module's 'weight' attribute and can be accessed individually as:

conv1_sparsity = (torch.sum(model.conv1.weight == 0) / model.conv1.weight.nelement()) * 100

conv2_sparsity = (torch.sum(model.conv2.weight == 0) / model.conv2.weight.nelement()) * 100

fc1_sparsity = (torch.sum(model.fc1.weight == 0) / model.fc1.weight.nelement()) * 100
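
Looking at the parameter names rather than just the shapes suggests what is going on: 'model.parameters()' yields the dense 'weight_orig' tensors (plus the biases), never the masked 'weight' attribute, so no zeros are found. A quick sketch to confirm (assuming the reparametrized names described above):

for name, param in model.named_parameters():
    # prints conv1.bias, conv1.weight_orig, conv2.bias, ... but never conv1.weight
    print(name, param.shape)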

Is there a way to access 'weight' from a loop, as in "calculate_sparsity()"?
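
For reference, a minimal sketch of the kind of loop I am after (a hypothetical 'calculate_pruned_sparsity' helper that walks the modules and reads their masked 'weight' attribute directly; it counts only the layer weights and skips the unpruned biases):

def calculate_pruned_sparsity(model):
    '''
    Hypothetical variant of calculate_sparsity() that reads each
    module's masked 'weight' attribute instead of model.parameters().
    '''
    num_zeros = 0
    num_wts_global = 0

    for name, module in model.named_modules():
        # only the Conv2d/Linear layers carry a pruned 'weight' here
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            loc_zeros = torch.sum(module.weight == 0).item()
            num_zeros += loc_zeros
            num_wts_global += module.weight.nelement()
            print(f"{name}.weight: # of wts = {module.weight.nelement()} & # of zeros = {loc_zeros}")

    print(f"\nGlobal sparsity = {100. * num_zeros / num_wts_global:.2f}%")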