Neural Network Pruning - from scratch

I am trying to implement a global, unstructured pruning algorithm closely based on The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks (Frankle et al.) and Comparing Rewinding and Fine-tuning in Neural Network Pruning (Renda et al.). I am aware that PyTorch already provides pruning utilities in torch.nn.utils.prune; my motivation for implementing things from scratch is learning and research.

I use a toy LeNet-300-100 dense network with MNIST for this task. The class architecture is:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm


class LeNet300(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Define layers-
        self.fc1 = nn.Linear(in_features = 28 * 28 * 1, out_features = 300)
        self.fc2 = nn.Linear(in_features = 300, out_features = 100)
        self.output = nn.Linear(in_features = 100, out_features = 10)
        
        # self.weights_initialization()
    
    
    def forward(self, x):
        out = F.leaky_relu(self.fc1(x))
        out = F.leaky_relu(self.fc2(out))
        return self.output(out)

The function to train for one epoch is:

def train_model_progress(model, train_loader, train_dataset):
    '''
    Perform one epoch of training using 'train_loader'.
    Returns the average loss and accuracy (%) for this epoch.
    '''
    running_loss = 0.0
    running_corrects = 0.0
    
    model.train()
    
    with tqdm(train_loader, unit = 'batch') as tepoch:
        for images, labels in tepoch:
            tepoch.set_description(f"Training: ")
            
            # images = images.view(-1, 28 * 28 * 1)
            
            images = images.to(device)
            labels = labels.to(device)
            
            # Get model predictions-
            outputs = model(images)
            
            # Compute loss-
            J = loss(outputs, labels)
            
            # Empty accumulated gradients-
            optimizer.zero_grad()
            
            # Perform backprop-
            J.backward()
            
            # Remove/Zero-out all gradients corresponding to pruned connections-
            # for name, param in model.named_parameters():
            for param in model.parameters():
                '''
                if 'mask' in name:
                    continue
                '''
                tensor = param.data.cpu().numpy()
                grad_tensor = param.grad.data.cpu().numpy()
                grad_tensor = np.where(tensor == 0., 0., grad_tensor)
                param.grad.data = torch.from_numpy(grad_tensor).to(device)

            
            # Update parameters-
            optimizer.step()
            
            '''
            global step
            optimizer.param_groups[0]['lr'] = custom_lr_scheduler.get_lr(step)
            step += 1
            '''
            
            # Compute model's performance statistics-
            running_loss += J.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            running_corrects += torch.sum(predicted == labels.data)
            
            tepoch.set_postfix(
                loss = running_loss / len(train_dataset),
                accuracy = (running_corrects.double().cpu().numpy() / len(train_dataset)) * 100
            )
            
    
    train_loss = running_loss / len(train_dataset)
    train_acc = (running_corrects.double() / len(train_dataset)) * 100

    return train_loss, train_acc.cpu().numpy()
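
For context, here is a minimal sketch of how this per-epoch function might be driven (num_epochs is an assumed variable; the notebook's actual loop is the train_until_convergence call shown further below):

# Hypothetical outer loop; assumes num_epochs, model, train_loader, train_dataset exist.
for epoch in range(num_epochs):
    train_loss, train_acc = train_model_progress(model, train_loader, train_dataset)
    print(f"epoch {epoch + 1}: loss = {train_loss:.4f}, accuracy = {train_acc:.2f}%")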

The focus is on these critical lines:

for param in model.parameters():
    '''
    if 'mask' in name:
        continue
    '''
    tensor = param.data
    # Or: tensor = param.data.cpu().numpy()
    grad_tensor = param.grad.data
    # Or: grad_tensor = param.grad.data.cpu().numpy()
    grad_tensor = torch.where(tensor == 0., 0., grad_tensor)
    # Or: grad_tensor = np.where(tensor == 0., 0., grad_tensor)
    param.grad.data = grad_tensor.to(device)
    # Or: param.grad.data = torch.from_numpy(grad_tensor).to(device)

The idea is simple: iterate through the model's trainable parameters and copy each parameter and its corresponding computed gradient into tensor and grad_tensor. Then, within grad_tensor, assign 0 at every position where the corresponding weight is pruned (i.e., is 0).
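
For reference, this is a minimal sketch of the same gradient-masking step done entirely in PyTorch, with no round trip through NumPy (it assumes the same model as above and recovers the mask as param != 0):

# Sketch: zero out gradients of pruned (zero-valued) weights, in place.
# Assumes `model` is defined as above; the pruning mask is recovered as (param != 0).
for param in model.parameters():
    if param.grad is not None:
        param.grad.mul_((param != 0.).to(param.grad.dtype))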

The problem is that this does not behave as expected. Sometimes the sparsity is preserved, but most of the time it is not. For example, an unpruned LeNet-300-100 has 266610 non-zero parameters, and after 20% global pruning, 213288 parameters survive. But when retraining the pruned model, instead of keeping 213288 non-zero parameters, the model reverts to 266610 non-zero parameters, destroying the pruning.
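
The non-zero counts quoted above can be obtained with a small helper like this (count_nonzero_params is a hypothetical name, not part of the notebook):

# Hypothetical helper: count surviving (non-zero) parameters of a model.
def count_nonzero_params(model):
    return int(sum((p != 0.).sum().item() for p in model.parameters()))

# Expected: 266610 before pruning, 213288 after 20% global pruning.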

What am I missing?

Your code works for me after making it executable:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class LeNet300(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Define layers-
        self.fc1 = nn.Linear(in_features = 28 * 28 * 1, out_features = 300)
        self.fc2 = nn.Linear(in_features = 300, out_features = 100)
        self.output = nn.Linear(in_features = 100, out_features = 10)
        
        # self.weights_initialization()
    
    
    def forward(self, x):
        out = F.leaky_relu(self.fc1(x))
        out = F.leaky_relu(self.fc2(out))
        return self.output(out)
    
    
        
model = LeNet300()

# mask random parameters
with torch.no_grad():
    model.fc1.weight *= torch.randint(0, 2, (model.fc1.weight.shape))
    model.fc2.weight *= torch.randint(0, 2, (model.fc2.weight.shape))
    model.output.weight *= torch.randint(0, 2, (model.output.weight.shape))
    
zero_idx = {
    "fc1.weight": model.fc1.weight==0,
    "fc2.weight": model.fc2.weight==0,
    "output.weight": model.output.weight==0.
}

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1.)


labels = torch.randint(0, 10, (16,))
images = torch.randn(16, 28*28)
outputs = model(images)
            
J = criterion(outputs, labels)
optimizer.zero_grad()
J.backward()

# Remove/Zero-out all gradients corresponding to pruned connections-
# for name, param in model.named_parameters():
for param in model.parameters():
    '''
    if 'mask' in name:
        continue
    '''
    tensor = param.data.cpu().numpy()
    grad_tensor = param.grad.data.cpu().numpy()
    grad_tensor = np.where(tensor == 0., 0., grad_tensor)
    param.grad.data = torch.from_numpy(grad_tensor)


optimizer.step()
sd = model.state_dict()

for name in sd:
    if name in zero_idx:
        ref = zero_idx[name]
        current = sd[name] == 0.
        print("expected zero values at {}".format(ref.nonzero()))
        print("got zero values at {}".format(current.nonzero()))
        print("matching: {}".format((ref == current).all()))    

Output:

expected zero values at tensor([[  0,   1],
        [  0,   2],
        [  0,   3],
        ...,
        [299, 780],
        [299, 781],
        [299, 783]])
got zero values at tensor([[  0,   1],
        [  0,   2],
        [  0,   3],
        ...,
        [299, 780],
        [299, 781],
        [299, 783]])
matching: True
expected zero values at tensor([[  0,   0],
        [  0,   4],
        [  0,   7],
        ...,
        [ 99, 295],
        [ 99, 296],
        [ 99, 298]])
got zero values at tensor([[  0,   0],
        [  0,   4],
        [  0,   7],
        ...,
        [ 99, 295],
        [ 99, 296],
        [ 99, 298]])
matching: True
expected zero values at tensor([[ 0,  0],
        [ 0,  1],
        [ 0,  3],
        ...,
        [ 9, 92],
        [ 9, 94],
        [ 9, 97]])
got zero values at tensor([[ 0,  0],
        [ 0,  1],
        [ 0,  3],
        ...,
        [ 9, 92],
        [ 9, 94],
        [ 9, 97]])
matching: True

Hey @ptrblck, thanks for your reply! I have uploaded a reproducible example as a Jupyter notebook, which you can access here.

The code to implement global, unstructured, magnitude pruning is:

def prune_globally(model, pruning_percentile = 20):
    # Python 3 list to hold layer-wise weights-
    pruned_weights = []
    
    for param in model.parameters():
        wts = np.copy(param.detach().cpu().numpy())
        pruned_weights.append(wts)
    
    del param, wts
    
    # Flatten all numpy arrays-
    pruned_weights_flattened = [layer.flatten() for layer in pruned_weights]

    threshold = np.percentile(a = abs(np.concatenate(pruned_weights_flattened)), q = pruning_percentile)
    # print("\nFor p = {0:.2f}% of weights to be pruned, threshold = {1:.4f}\n".format(p, threshold))
    
    # Prune conv (4-D) and dense (2-D) weight tensors;
    # biases and batch-norm parameters are NOT pruned.
    for layer in pruned_weights:
        if len(layer.shape) == 4:
            layer[abs(layer) < threshold] = 0
        elif len(layer.shape) == 2:
            layer[abs(layer) < threshold] = 0
    
    
    i = 0
    model_d = dict()

    for name, params in model.named_parameters():
        if pruned_weights[i].shape == params.shape:
            model_d[name] = torch.from_numpy(pruned_weights[i])

        i += 1
        
        
    state_d = model.state_dict()

    for layer_name in model_d:
        # if pruned_model.state_dict().get(layer_name) is not None:
        if state_d.get(layer_name) is not None:
            # print(layer_name)
            state_d[layer_name] = model_d.get(layer_name)

    model.load_state_dict(state_d)
        
    return None
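
For comparison, the built-in torch.nn.utils.prune API mentioned at the top can express the same global, unstructured, magnitude (L1) pruning; this is only a sketch of the equivalent call, not the code used in the notebook:

import torch.nn.utils.prune as prune

# Sketch: remove the 20% smallest-magnitude weights globally across the
# weight matrices of the three Linear layers (biases are left untouched).
parameters_to_prune = (
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.output, 'weight'),
)
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.20,
)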

Within the notebook, after pruning the trained model to 20% sparsity in cell number 35 with:

prune_globally(model, pruning_percentile = sparsity_percentage[0])

The number of trainable surviving parameters = 213288. Upon re-training the pruned model in cell number 37 with:

# Train until convergence pruned initialized model-
history_pr = train_until_convergence(
    model = model, train_loader = train_loader,
    test_loader = test_loader, train_dataset = train_dataset,
    test_dataset = test_dataset, num_epochs = 20
)

The sparsity is destroyed, as the number of trainable non-zero parameters falls back to 266610 instead of 213288.

Something is bringing back the parameters that are supposed to stay pruned at 0. What is it?