How to prune weights in a network

Hello everyone,
I am a new PyTorch user, so I am fairly inexperienced. How can I zero out some weights of a fully connected layer in PyTorch? I want to set certain weights to zero and keep them at zero as training progresses. Can anyone point me to some solutions?

You need to access the relevant parameters and set them to 0, then set their .requires_grad attribute to False so that they don’t change during training.

In PyTorch you access the weights with the .parameters() method, and a model’s children (its direct submodules) with model.children().

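import torch
with torch.no_grad():
    for child in model.children():
        for param in child.parameters():
            param.zero_()                # set to 0 in place (add a condition here to pick specific parameters)
            param.requires_grad = False  # keep this tensor fixed during training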

Hope this helps!
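A minimal runnable sketch of the approach above, assuming a small hypothetical two-layer model: it zeroes and freezes only the first layer, then checks that an optimizer step leaves that layer at zero. Note that requires_grad applies to a whole tensor, so this freezes every element of the layer at once, which is the limitation raised in the next reply.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical two-layer model used only for illustration.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

first_layer = model[0]
with torch.no_grad():
    for param in first_layer.parameters():
        param.zero_()                # set every element to 0 in place
for param in first_layer.parameters():
    param.requires_grad = False      # freeze the whole tensor

# Give the optimizer only the parameters that are still trainable.
optimizer = torch.optim.SGD(
    (p for p in model.parameters() if p.requires_grad), lr=0.1
)

x = torch.randn(8, 4)
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()

print(first_layer.weight.abs().sum().item())  # 0.0
```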

If I follow this solution, I can only freeze all the weights of a tensor at once, not individual elements of the weight matrix. Ultimately I want two layers with sparse connections, i.e. y = Wx where W is a weight matrix in which some elements stay zero throughout training.
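One common way to get y = Wx with selected entries of W held at zero is to multiply the weight by a fixed 0/1 mask in the forward pass. The sketch below assumes a hypothetical MaskedLinear module and a randomly chosen sparsity pattern; because the mask multiplies the weight, masked entries also receive zero gradient, so they never contribute to the output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedLinear(nn.Linear):
    """Hypothetical nn.Linear variant whose weight is multiplied by a fixed 0/1 mask."""

    def __init__(self, in_features, out_features, mask, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        if mask.shape != self.weight.shape:
            raise ValueError("mask must have the same shape as the weight")
        # A buffer moves with .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer("mask", mask.to(self.weight.dtype))
        with torch.no_grad():
            self.weight.mul_(self.mask)  # start with the masked entries at zero

    def forward(self, x):
        # y = (W * mask) x + b: masked connections contribute nothing
        # and receive zero gradient.
        return F.linear(x, self.weight * self.mask, self.bias)


torch.manual_seed(0)
mask = torch.rand(3, 4) > 0.5  # hypothetical sparsity pattern (True = keep)
layer = MaskedLinear(4, 3, mask)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

for _ in range(5):
    optimizer.zero_grad()
    loss = layer(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()

effective = layer.weight * layer.mask
print(effective[~mask].abs().sum().item())  # 0.0
```

Keeping the mask as a buffer rather than a parameter means it is saved with the model but never updated by the optimizer.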

Here is a more detailed answer.

I have a tutorial that shows how to access specific elements in detail. It doesn’t set the weights to zero or requires_grad to False, but it shows how to access the parameters. Here’s the GitHub repo for the tutorial: https://github.com/Spandan-Madan/A-Collection-of-important-tasks-in-pytorch

In the tutorial, go to cell 22 and change the lines to:

import torch
for child in model.children():
    for param in child.parameters():
        with torch.no_grad():
            param.zero_()  # in place; `param = torch.zeros(...)` would only rebind the loop variable
        param.requires_grad = False

This zeroes out every parameter in the model. To prune only specific nodes, add if statements to the code above so that only the required ones are zeroed, for example with another loop over the elements of each tensor returned by child.parameters().
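Since requires_grad applies to a whole tensor, it cannot freeze individual elements. A minimal sketch of element-level selection, assuming a hypothetical magnitude threshold as the "if" condition: a boolean mask replaces the per-element loop, and a gradient hook (Tensor.register_hook) zeroes the gradient of the pruned entries so plain SGD leaves them at zero. Optimizers carrying state from earlier steps (e.g. momentum) could still move them, in which case re-applying the mask under torch.no_grad() after each step is a simple safeguard.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))  # hypothetical model
threshold = 0.2  # hypothetical pruning criterion: prune weights with |w| < threshold

masks = {}
for name, param in model.named_parameters():
    if not name.endswith("weight"):
        continue  # leave biases alone in this sketch
    keep = param.detach().abs() >= threshold  # the element-wise "if" condition
    with torch.no_grad():
        param.masked_fill_(~keep, 0.0)  # zero the selected elements in place
    # Zero the gradient of pruned elements on every backward pass.
    param.register_hook(lambda grad, keep=keep: grad.masked_fill(~keep, 0.0))
    masks[name] = keep

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(5):
    optimizer.zero_grad()
    model(torch.randn(8, 4)).pow(2).mean().backward()
    optimizer.step()

for name, param in model.named_parameters():
    if name in masks:
        # Count pruned entries that became nonzero: 0 for every layer.
        print(name, int((param[~masks[name]] != 0).sum()))
```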

Hope this helps!

Hello, I have the same question. Did you solve your problem?

Starting from v1.4, PyTorch includes out-of-the-box pruning functionality that handles this for you, for a subset of pruning techniques. You can find it under torch.nn.utils.prune.
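A minimal sketch using torch.nn.utils.prune.custom_from_mask, which matches the original question of zeroing specific elements; the layer size and mask pattern are hypothetical. Other functions in the module, such as prune.l1_unstructured, choose the mask for you instead.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)

# Hypothetical sparsity pattern: 1 keeps a connection, 0 removes it.
mask = torch.tensor([[1, 0, 1, 0],
                     [0, 1, 0, 1],
                     [1, 1, 0, 0]], dtype=torch.float32)
prune.custom_from_mask(layer, name="weight", mask=mask)

# The trainable tensor is now weight_orig; weight = weight_orig * weight_mask
# is recomputed by a forward pre-hook, so pruned entries stay zero.
print(sorted(n for n, _ in layer.named_parameters()))  # ['bias', 'weight_orig']
print(sorted(n for n, _ in layer.named_buffers()))     # ['weight_mask']

# Create the optimizer after pruning so it sees weight_orig.
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
for _ in range(5):
    optimizer.zero_grad()
    layer(torch.randn(8, 4)).pow(2).mean().backward()
    optimizer.step()

# Make the pruning permanent: fold the mask into a plain weight parameter.
prune.remove(layer, "weight")
print(sorted(n for n, _ in layer.named_parameters()))  # ['bias', 'weight']
print(int((layer.weight[mask == 0] == 0).sum()))        # 6
```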