Modifying gradients stops training

I am trying to train a pruned neural network with (say) 36% sparsity, meaning that 36% of the trainable parameters are 0s. I am using a simple LeNet-300-100 dense network for this:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet300(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Define layers-
        self.fc1 = nn.Linear(in_features = 28 * 28 * 1, out_features = 300)
        self.fc2 = nn.Linear(in_features = 300, out_features = 100)
        self.output = nn.Linear(in_features = 100, out_features = 10)
        
        # self.weights_initialization()
    
    
    def forward(self, x):
        out = F.leaky_relu(self.fc1(x))
        out = F.leaky_relu(self.fc2(out))
        return self.output(out)

After pruning to 36% sparsity, the number of surviving trainable parameters = 170630 (the layer-wise counts below sum to this). Originally, without pruning, the network has 266610 non-zero parameters.

The layer-wise sparsity looks like this:

for param in model.parameters():
    print(f"{param.size()} has {torch.count_nonzero(param)} surviving weights")

'''
torch.Size([300, 784]) has 146615 surviving weights
torch.Size([300]) has 300 surviving weights
torch.Size([100, 300]) has 22743 surviving weights
torch.Size([100]) has 100 surviving weights
torch.Size([10, 100]) has 862 surviving weights
torch.Size([10]) has 10 surviving weights
'''
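
As a quick sanity check, the layer-wise counts can be summed in one line:

total_surviving = sum(torch.count_nonzero(param).item() for param in model.parameters())
print(total_surviving)
# 170630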

The code I am using to preserve sparsity is:

In .grad (gradient_t), replace the computed gradients with 0s at the positions (within the matrix) where the trainable parameters (wts) have been pruned:

# Compute loss-
J = loss(outputs, labels)
            
# Empty accumulated gradients-
optimizer.zero_grad()
            
# Perform backprop-
J.backward()
            
for name, param in model.named_parameters():
    wts = param.data.clone().detach()
    gradient_t = param.grad
    gradient_t = torch.where(wts == 0., 0., gradient_t)
    param.grad = gradient_t
        
# Update parameters-
optimizer.step()

However, after this, training basically freezes: the loss and accuracy stop changing.

What is going wrong here?

Hi Arjun!

In principle your scheme should work. Consider this example script:

import torch
print (torch.__version__)

s = torch.tensor ([2.0, 0.0, -1.0], requires_grad = True)   # a trainable parameter
t = torch.tensor ([1.0, 1.0, 1.0])

optimizer = torch.optim.SGD ([s], lr = 0.2)                 # plain-vanilla SGD

for  i in range (10):
    loss = ((s - t)**2).sum()
    optimizer.zero_grad()
    loss.backward()
    wts = s.data.clone().detach()                           # don't use .data -- it is deprecated
    gradient_t = s.grad
    gradient_t = torch.where (wts == 0., 0., gradient_t)
    s.grad = gradient_t
    optimizer.step()
    print (s)                                               # s trains nicely with s[1] frozen

Here is its output:

.0.1
tensor([ 1.6000,  0.0000, -0.2000], requires_grad=True)
tensor([1.3600, 0.0000, 0.2800], requires_grad=True)
tensor([1.2160, 0.0000, 0.5680], requires_grad=True)
tensor([1.1296, 0.0000, 0.7408], requires_grad=True)
tensor([1.0778, 0.0000, 0.8445], requires_grad=True)
tensor([1.0467, 0.0000, 0.9067], requires_grad=True)
tensor([1.0280, 0.0000, 0.9440], requires_grad=True)
tensor([1.0168, 0.0000, 0.9664], requires_grad=True)
tensor([1.0101, 0.0000, 0.9798], requires_grad=True)
tensor([1.0060, 0.0000, 0.9879], requires_grad=True)

After what, specifically?

What does “basically freezes” mean? Do your weights stop changing
entirely, or does your training just slow down?

Do you get any non-zero gradients? When you call optimizer.step(),
do the weights with non-zero gradients change at all? Do your outputs
(predictions) change at all? If your predictions change some, do your
loss and accuracy change at least a little bit, or do they not change at
all?

As an aside, although it isn’t the cause of your problem, you shouldn’t
use .data, as it’s deprecated and can break things. It also isn’t necessary
in your case. I would suggest modifying the gradients that you want to
zero out in place, and wrapping that whole bit in a with torch.no_grad():
block.

Best.

K. Frank

Hey Frank,

Thanks for your detailed answer. Can you provide an example of modifying the gradients by zeroing them out in place?

@KFrank I have uploaded an example reproducible code as a jupyter notebook which you can access here.

The code to implement global, unstructured, magnitude pruning is:

def prune_globally(model, pruning_percentile = 20):
    # Python 3 list to hold layer-wise weights-
    pruned_weights = []
    
    for param in model.parameters():
        wts = np.copy(param.detach().cpu().numpy())
        pruned_weights.append(wts)
    
    del param, wts
    
    # Flatten all numpy arrays-
    pruned_weights_flattened = [layer.flatten() for layer in pruned_weights]

    threshold = np.percentile(a = abs(np.concatenate(pruned_weights_flattened)), q = pruning_percentile)
    # print("\nFor p = {0:.2f}% of weights to be pruned, threshold = {1:.4f}\n".format(p, threshold))
    
    # Prune conv (4-D) and dense (2-D) layers;
    # biases and batch-norm parameters are NOT pruned-
    for layer in pruned_weights:
        if len(layer.shape) in (4, 2):
            layer[abs(layer) < threshold] = 0
    
    model_d = dict()
    
    for i, (name, params) in enumerate(model.named_parameters()):
        if pruned_weights[i].shape == params.shape:
            model_d[name] = torch.from_numpy(pruned_weights[i])
        
        
    state_d = model.state_dict()

    for layer_name in model_d:
        # if pruned_model.state_dict().get(layer_name) is not None:
        if state_d.get(layer_name) is not None:
            # print(layer_name)
            state_d[layer_name] = model_d.get(layer_name)

    model.load_state_dict(state_d)
        
    return None
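
(As an aside, the same global magnitude pruning can be written without the numpy round-trip. A minimal torch-only sketch, assuming the same rule of thresholding over the magnitudes of all parameters but zeroing only the 2-D/4-D weight tensors:)

def prune_globally_torch(model, pruning_percentile = 20):
    # Global threshold over the magnitudes of ALL parameters-
    all_mags = torch.cat([p.detach().abs().flatten() for p in model.parameters()])
    threshold = torch.quantile(all_mags, pruning_percentile / 100.0)
    
    # Zero out conv (4-D) and dense (2-D) weights in place;
    # biases and batch-norm parameters are NOT pruned-
    with torch.no_grad():
        for p in model.parameters():
            if p.dim() in (2, 4):
                p[p.abs() < threshold] = 0.0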

Within the notebook, after pruning the trained model to 20% sparsity in cell number 35 with:

prune_globally(model, pruning_percentile = sparsity_percentage[0])

The number of trainable surviving parameters = 213288. Upon re-training the pruned model in cell number 37 with:

# Train until convergence pruned initialized model-
history_pr = train_until_convergence(
    model = model, train_loader = train_loader,
    test_loader = test_loader, train_dataset = train_dataset,
    test_dataset = test_dataset, num_epochs = 20
)

The sparsity is destroyed: the number of trainable parameters climbs back to 266610 instead of staying at 213288.

Something is bringing back the parameters that are supposed to stay pruned at 0?!

Hi Arjun!

In general, I would look to .copy_(), but in this case I might use .mul_():

with  torch.no_grad():
    for  name, param in model.named_parameters():
        param.grad.mul_ (~(param == 0.0))

Best.

K. Frank

I changed the train function to include your code:

def train_model_progress(
    model, 
    train_loader, train_dataset
):
    '''
    Function to perform one epoch of training by using 'train_loader'.
    Returns loss and number of correct predictions for this epoch.
    '''
    running_loss = 0.0
    running_corrects = 0.0
    
    model.train()
    
    with tqdm(train_loader, unit = 'batch') as tepoch:
        for images, labels in tepoch:
            tepoch.set_description(f"Training: ")
            
            images = images.view(-1, 28 * 28 * 1)
            
            images = images.to(device)
            labels = labels.to(device)
            
            # Get model predictions-
            outputs = model(images)
            
            # Compute loss-
            J = loss(outputs, labels)
            
            # Empty accumulated gradients-
            optimizer.zero_grad()
            
            # Perform backprop-
            J.backward()
            
            
            with  torch.no_grad():
                for  name, param in model.named_parameters():
                    param.grad.mul_(~(param == 0.0))
            
            # Update parameters-
            optimizer.step()
                        
            # Compute model's performance statistics-
            running_loss += J.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            running_corrects += torch.sum(predicted == labels)
            
            tepoch.set_postfix(
                loss = running_loss / len(train_dataset),
                accuracy = (running_corrects.double().cpu().numpy() / len(train_dataset)) * 100
            )
            
    
    train_loss = running_loss / len(train_dataset)
    train_acc = (running_corrects.double() / len(train_dataset)) * 100

    return train_loss, train_acc.cpu().numpy()

But after the first pruning round, which removes 20% of the parameters, the number of trainable parameters falls back to the original count, destroying the pruning. Alas!

@KFrank I think optimizer.step() destroys sparsity. Look at this minimal code:

images, labels = next(iter(train_loader))
images = images.view(-1, 28 * 28 * 1)
images = images.to(device)
labels = labels.to(device)
            
# Get model predictions-
outputs = model(images)
            
# Compute loss-
J = loss(outputs, labels)
            
# Empty accumulated gradients-
optimizer.zero_grad()
            
# Perform backprop-
J.backward()

# before grad multiplication-
for param in model.parameters():
    print(f"{param.size()}; wts = {torch.count_nonzero(param)}",
          f" & grads = {torch.count_nonzero(param.grad)}"
         )
'''
torch.Size([300, 784]); wts = 185620  & grads = 235200
torch.Size([300]); wts = 300  & grads = 300
torch.Size([100, 300]); wts = 26318  & grads = 30000
torch.Size([100]); wts = 100  & grads = 100
torch.Size([10, 100]); wts = 940  & grads = 1000
torch.Size([10]); wts = 10  & grads = 10
'''

# after grad multiplication-
with  torch.no_grad():
    # for name, param in model.named_parameters():
    for param in model.parameters():
        param.grad.mul_(~(param == 0.0))
        
        print(f"{param.size()}; wts = {torch.count_nonzero(param)}",
          f" & grads = {torch.count_nonzero(param.grad)}"
         )
'''
torch.Size([300, 784]); wts = 185620  & grads = 185620
torch.Size([300]); wts = 300  & grads = 300
torch.Size([100, 300]); wts = 26318  & grads = 26318
torch.Size([100]); wts = 100  & grads = 100
torch.Size([10, 100]); wts = 940  & grads = 940
torch.Size([10]); wts = 10  & grads = 10
'''

count_surviving_params(model)
# 213288

optimizer.step()

count_surviving_params(model)
# 266610

Just after optimizer.step() is called, the pruned parameters come back, tossing sparsity out of the window!

Hi Arjun!

You haven’t told us which optimizer you are using, but some optimizers,
even SGD with non-zero momentum, will modify parameters whose gradients
are (or have been forced to be) zero.
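
For example, here is a minimal illustration of how a momentum buffer can
move a parameter whose current gradient has been zeroed out:

import torch

p = torch.tensor ([2.0, -1.0], requires_grad = True)
opt = torch.optim.SGD ([p], lr = 0.1, momentum = 0.9)

# first step: both entries get non-zero gradients, creating a momentum buffer
loss = (p**2).sum()
opt.zero_grad()
loss.backward()
opt.step()

# second step: force p[1]'s gradient to zero, as the masking scheme does
loss = (p**2).sum()
opt.zero_grad()
loss.backward()
with torch.no_grad():
    p.grad[1] = 0.0
opt.step()

print (p)   # p[1] has still moved -- the momentum buffer carried it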

My post above shows a simple example where your scheme works with
plain-vanilla SGD.

You have two choices: Zero out the gradients and use an optimizer that
does not modify parameters that have zero gradients; or zero out the
parameters themselves after optimizer.step().
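
For the second choice, something like this should work (it assumes you
capture the masks once, right after pruning, while the pruned weights are
still exactly zero):

# capture the binary masks once, right after pruning
masks = [(p != 0.0) for p in model.parameters()]

# then, inside the training loop, after each optimizer.step()
with torch.no_grad():
    for p, mask in zip (model.parameters(), masks):
        p.mul_ (mask)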

Best.

K. Frank

In the Lottery Ticket Hypothesis paper and other similar papers, the authors use the Adam and/or SGD optimizers.

The code that I had implemented in TensorFlow 2.x is:

# Define 'train_one_step()' and 'test_step()' functions here-
@tf.function
def train_one_step(model, mask_model, optimizer, x, y):
    '''
    Function to compute one step of gradient descent optimization
    '''
    with tf.GradientTape() as tape:
        # Make predictions using defined model-
        y_pred = model(x)

        # Compute loss-
        loss = loss_fn(y, y_pred)

    # Compute gradients wrt defined loss and weights and biases-
    grads = tape.gradient(loss, model.trainable_variables)

    # type(grads)
    # list

    # List to hold element-wise multiplication between-
    # computed gradient and masks-
    grad_mask_mul = []

    # Perform element-wise multiplication between computed gradients and masks-
    for grad_layer, mask in zip(grads, mask_model.trainable_weights):
        grad_mask_mul.append(tf.math.multiply(grad_layer, mask))

    # Apply computed gradients to model's weights and biases-
    optimizer.apply_gradients(zip(grad_mask_mul, model.trainable_variables))

    # Compute accuracy-
    train_loss(loss)
    train_accuracy(y, y_pred)

    return None

mask_model holds binary 0/1 masks denoting which parameters have been pruned and which have not. Here, the sparsity is maintained by grad_mask_mul (the element-wise multiplication between the computed gradients and the binary masks), and by:

optimizer.apply_gradients(zip(grad_mask_mul, model.trainable_variables))

which applies the masked gradients (stored in grad_mask_mul) to the model's trainable parameters, so the pruned positions never receive an update.

This worked for SGD and Adam.

Is there something similar in PyTorch?
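
(The closest translation I can think of would be a sketch along these lines, where masks is a hypothetical dict of binary tensors playing the role of mask_model:)

# hypothetical 'masks': binary 0/1 tensors keyed by parameter name,
# playing the role of mask_model.trainable_weights-
masks = {name: (param != 0.0).float() for name, param in model.named_parameters()}

# inside the training loop-
outputs = model(images)
J = loss(outputs, labels)
optimizer.zero_grad()
J.backward()

# element-wise multiplication between computed gradients and masks,
# mirroring grad_mask_mul-
with torch.no_grad():
    for name, param in model.named_parameters():
        param.grad.mul_(masks[name])

optimizer.step()

# re-apply the masks after the step so that optimizers with internal
# state (momentum, Adam) cannot resurrect pruned weights-
with torch.no_grad():
    for name, param in model.named_parameters():
        param.mul_(masks[name])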