I am using " torch.nn.utils.prune" to prune a LeNet- CNN as mentioned in the PyTorch Pruning Tutorial. The code is as follows:
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
# 1 input image channel, 6 output channels, 3x3 square conv kernel
self.conv1 = nn.Conv2d(1, 6, 3)
self.conv2 = nn.Conv2d(6, 16, 3)
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5x5 image dimension
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, int(x.nelement() / x.shape[0]))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = LeNet().to(device=device)
# Define parameters to prune-
parameters_to_prune = (
(model.conv1, 'weight'),
(model.conv2, 'weight'),
(model.fc1, 'weight'),
(model.fc2, 'weight'),
(model.fc3, 'weight'),
)
# Define unstructured global pruning-
prune.global_unstructured(
parameters_to_prune,
pruning_method = prune.L1Unstructured,
amount = 0.2
)
def calculate_sparsity(model):
'''
Computes layer-wise & global sparsity
Refer-
https://discuss.pytorch.org/t/how-to-count-the-number-of-zero-weights-in-a-pytorch-model/13549
'''
num_zeros = 0
num_wts_global = 0
for param in model.parameters():
# print(param.shape)
loc_zeros = torch.sum((param == 0).int()).item()
num_zeros += loc_zeros
num_wts_global += param.numel()
print(f"layer.shape = {param.shape}, # of wts = {param.numel()} & # of zeros = {loc_zeros}")
print(f"\nGlobal sparsity = {num_zeros / num_wts_global:.2f}%")
However, executing the function gives the output-
calculate_sparsity(model)
layer.shape = torch.Size([6]), # of wts = 6 & # of zeros = 0
layer.shape = torch.Size([6, 1, 3, 3]), # of wts = 54 & # of zeros = 0
layer.shape = torch.Size([16]), # of wts = 16 & # of zeros = 0
layer.shape = torch.Size([16, 6, 3, 3]), # of wts = 864 & # of zeros = 0
layer.shape = torch.Size([120]), # of wts = 120 & # of zeros = 0
layer.shape = torch.Size([120, 400]), # of wts = 48000 & # of zeros = 0
layer.shape = torch.Size([84]), # of wts = 84 & # of zeros = 0
layer.shape = torch.Size([84, 120]), # of wts = 10080 & # of zeros = 0
layer.shape = torch.Size([10]), # of wts = 10 & # of zeros = 0
layer.shape = torch.Size([10, 84]), # of wts = 840 & # of zeros = 0Global sparsity = 0.00%
This clearly shows that ‘param’ is not able to access the pruned weights which are stored as ‘weight’ attribute and can be accessed as-
conv1_sparsity = (torch.sum(best_model.conv1.weight == 0) / best_model.conv1.weight.nelement()) * 100
conv2_sparsity = (torch.sum(best_model.conv2.weight == 0) / best_model.conv2.weight.nelement()) * 100
conv3_sparsity = (torch.sum(best_model.conv3.weight == 0) / best_model.conv3.weight.nelement()) * 100
Is there a way to access ‘weight’ from a loop as used in “calculate_sparsity()”?