Loading modified parameters with "load_state_dict()"

I have a CNN:

class LeNet5(nn.Module):
    def __init__(self):
    # def __init__(self, beta = 1.0):
        super().__init__()
        
        # Trainable parameter for swish activation function-
        # self.beta = nn.Parameter(torch.tensor(beta, requires_grad = True))
        
        self.conv1 = nn.Conv2d(
            in_channels = 1, out_channels = 6, 
            kernel_size = 5, stride = 1,
            padding = 0, bias = False 
        )
        self.bn1 = nn.BatchNorm2d(num_features = 6)
        self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.conv2 = nn.Conv2d(
            in_channels = 6, out_channels = 16,
            kernel_size = 5, stride = 1,
            padding = 0, bias = False
        )
        self.bn2 = nn.BatchNorm2d(num_features = 16)
        self.fc1 = nn.Linear(
            in_features = 256, out_features = 120,
            bias = False
        )
        self.bn3 = nn.BatchNorm1d(num_features = 120)
        self.fc2 = nn.Linear(
            in_features = 120, out_features = 84,
            bias = False
        )
        self.bn4 = nn.BatchNorm1d(num_features = 84)
        self.fc3 = nn.Linear(
            in_features = 84, out_features = 10,
            bias = True
        )
        
        # self.initialize_weights()

        
    def initialize_weights(self):
        for m in self.modules():
            # print(m)
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                
                # Do not initialize bias (due to batchnorm)-
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
                
            elif isinstance(m, nn.BatchNorm2d):
                # Standard initialization for batch normalization-
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                nn.init.constant_(m.bias, 0)
    
    
    def swish_fn(self, x):
        return x * torch.sigmoid(x * self.beta)

    
    def forward(self, x):
        
        x = nn.SiLU()(self.pool(self.bn1(self.conv1(x))))
        x = nn.SiLU()(self.pool(self.bn2(self.conv2(x))))
        x = x.view(-1, 256)
        x = nn.SiLU()(self.bn3(self.fc1(x)))
        x = nn.SiLU()(self.bn4(self.fc2(x)))
        x = self.fc3(x)
        return x
    
model = LeNet5()

I am trying to implement a global, unstructured pruning where all parameters whose magnitude is amongst the smallest 20th percentile are removed. Alternatively, smallest 20% of magnitude weights are removed/pruned. The code for this is:

def prune_globally(model, pruning_percentile = 20):
    # Python 3 list to hold layer-wise weights-
    pruned_weights = []
    
    for param in model.parameters():
        wts = np.copy(param.detach().cpu().numpy())
        pruned_weights.append(wts)
    
    del param, wts
    
    # Flatten all numpy arrays-
    pruned_weights_flattened = [layer.flatten() for layer in pruned_weights]

    threshold = np.percentile(a = abs(np.concatenate(pruned_weights_flattened)), q = pruning_percentile)
    # print("\nFor p = {0:.2f}% of weights to be pruned, threshold = {1:.4f}\n".format(p, threshold))
    
    # Prune conv and dense layers-
    # bias and batch-norm is NOT pruned.
    for layer in pruned_weights:
        if len(layer.shape) == 4:
            layer[abs(layer) < threshold] = 0
        elif len(layer.shape) == 2:
            layer[abs(layer) < threshold] = 0
    
    
    i = 0
    
    # Python3 dict as named_parameters-
    model_d = dict()

    for name, params in model.named_parameters():
        if pruned_weights[i].shape == params.shape:
            # print(f"match")
            model_d[name] = torch.from_numpy(pruned_weights[i])

        i += 1

        
    return model_d

This returns a Python3 dict. However, I cannot load this since it does not contain all of the buffers which occur along with trainable parameters in “state_dict()” as:

model.state_dict().keys()
'''
odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'conv2.weight', 'bn2.weight', 'bn2.bias', 'bn2.running_mean', 'bn2.running_var', 'bn2.num_batches_tracked', 'fc1.weight', 'bn3.weight', 'bn3.bias', 'bn3.running_mean', 'bn3.running_var', 'bn3.num_batches_tracked', 'fc2.weight', 'bn4.weight', 'bn4.bias', 'bn4.running_mean', 'bn4.running_var', 'bn4.num_batches_tracked', 'fc3.weight', 'fc3.bias'])
'''

Whereas, the parameters I modify consists of:

model_d.keys()
dict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'conv2.weight', 'bn2.weight', 'bn2.bias', 'fc1.weight', 'bn3.weight', 'bn3.bias', 'fc2.weight', 'bn4.weight', 'bn4.bias', 'fc3.weight', 'fc3.bias'])

How can I get around this and load model_d modified dict containing pruned parameters?

The goal is to take a trained model and prune the trainable parameters. The resulting pruned parameters can be loaded into a new model.

The pruning is implemented in prune_globally() function which returns a Pyhton3 dict containing the pruned parameters. However, it does not contain non-trainable parameters, such as the buffers. Therefore, this is not equivalent to the state_dict of the model and cannot be used to load the pruned parameters.

So, what am I missing to convert model_d dict into something which I can use to load it into a model?

Could you please elaborate on what the desired behavior would be?

Your code removes mostly the parameters of batch normalization, i.e., the running mean and variance. When running your model in eval mode, the running mean and variance are used for inference.

Say, hypothetically, that your procedure would say: " conv1.weight should be removed. What would you want the resulting model to be in such case?