How to prune weights in PyTorch


(Luyu Wang) #1

I’m trying to replicate the work of Han et al. (Learning both Weights and Connections for Efficient Neural Networks, 2015), where deep CNN models are compressed by pruning close-to-zero weights and then retraining the model. There are two training phases: in the first stage the model is trained as usual and is then used to find the weights below a certain threshold; those insignificant weights are pruned, resulting in a simpler model, and the remaining parameters are kept for another fine-tuning session.

My idea for implementing this in PyTorch: given the trained model from the first stage, I set the weights below the threshold to zero (their indices are stored in pruned_inds_by_layer), and then start the second training stage, in which I don't allow any gradient to be back-propagated to those zero-valued weights. But modifying p.grad.data as below doesn't seem to work: those zero-valued weights still get gradients, which makes them non-zero again. Any idea how to solve this problem?

optimizer.zero_grad()
outputs = cnn(images)
loss = criterion(outputs, labels)
loss.backward()

# zero-out all the gradients corresponding to the pruned connections
for l,p in enumerate(cnn.parameters()):
    pruned_inds = pruned_inds_by_layer[l]
    p.grad.data[pruned_inds] = 0.

optimizer.step()

(Luyu Wang) #2

I figured it out myself: I was using Adam as the optimizer, which keeps running averages of the gradients from previous steps (its momentum terms), so setting the current gradients to zero by hand does not cancel the update coming from those stored averages. Using RMSprop solves this problem.
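A minimal sketch of the effect described above, assuming the optimizer state carries over from the first training stage (a freshly created Adam optimizer, whose moments start at zero, would not move a weight whose gradient is always zero). A single weight takes one ordinary step so the optimizer builds up state, then one step with its gradient forced to zero. The helper step_with_zeroed_grad is hypothetical, made up for illustration.

```python
import torch


def step_with_zeroed_grad(opt_cls, **opt_kwargs):
    """Hypothetical helper: how far one weight moves on a step where its gradient is zeroed."""
    w = torch.nn.Parameter(torch.tensor([1.0]))
    opt = opt_cls([w], lr=0.1, **opt_kwargs)

    # Stage 1: an ordinary step, so the optimizer accumulates internal state.
    w.grad = torch.tensor([0.5])
    opt.step()
    before = w.detach().clone()

    # Stage 2: the connection is "pruned", so its gradient is zeroed by hand.
    w.grad = torch.zeros_like(w)
    opt.step()
    return (w.detach() - before).abs().item()


adam_move = step_with_zeroed_grad(torch.optim.Adam)
rmsprop_move = step_with_zeroed_grad(torch.optim.RMSprop)
print(adam_move > 0, rmsprop_move == 0.0)  # True True
```

Adam still moves the weight because its stored first moment is non-zero, while RMSprop with its default momentum=0 scales the current gradient only, so a zero gradient gives a zero update. RMSprop with momentum > 0 (or SGD with momentum) would show the same problem as Adam.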


(Lolong) #3

Hi, after pruning the model, did you notice any speed-up on CPU/GPU?


(Luyu Wang) #4

Not with my implementation. I was just applying a mask to the CNN module: the weights are not actually pruned, just manually set to zero.


(Lolong) #5

I see. I think that after zeroing the weights, a new architecture definition has to be created manually (for instance, with a smaller number of neurons per layer), and the surviving parameters copied over with a bit of engineering.

However, the paper says that they train the connectivity between the neurons of the previous and current layers. I guess that is a bit different from your implementation, isn't it?
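The manual rebuild suggested in this post (defining a smaller architecture and copying the surviving parameters) could look roughly like this for two fully connected layers. This is a sketch under a stated assumption: pruning has zeroed entire columns of the second layer's weight matrix, i.e. every outgoing connection of some hidden neurons. That is structured, neuron-level pruning, which differs from the element-wise pruning in Han et al. The function shrink_hidden_layer is hypothetical.

```python
import torch
import torch.nn as nn


def shrink_hidden_layer(fc1: nn.Linear, fc2: nn.Linear):
    """Hypothetical helper: drop hidden units of fc1 -> act -> fc2 whose outgoing weights are all zero."""
    # Column j of fc2.weight holds every outgoing weight of hidden unit j.
    keep = fc2.weight.detach().abs().sum(dim=0) != 0
    n_keep = int(keep.sum())
    new_fc1 = nn.Linear(fc1.in_features, n_keep, bias=fc1.bias is not None)
    new_fc2 = nn.Linear(n_keep, fc2.out_features, bias=fc2.bias is not None)
    with torch.no_grad():
        new_fc1.weight.copy_(fc1.weight[keep])  # rows of the surviving units
        if fc1.bias is not None:
            new_fc1.bias.copy_(fc1.bias[keep])
        new_fc2.weight.copy_(fc2.weight[:, keep])  # and their columns in fc2
        if fc2.bias is not None:
            new_fc2.bias.copy_(fc2.bias)
    return new_fc1, new_fc2


torch.manual_seed(0)
fc1, fc2 = nn.Linear(4, 6), nn.Linear(6, 3)
with torch.no_grad():
    fc2.weight[:, [1, 4]] = 0.0  # pretend pruning removed all outgoing weights of units 1 and 4

small1, small2 = shrink_hidden_layer(fc1, fc2)
x = torch.randn(5, 4)
full_out = fc2(torch.relu(fc1(x)))
small_out = small2(torch.relu(small1(x)))
print(small1.out_features, torch.allclose(full_out, small_out, atol=1e-6))  # 4 True
```

Removing hidden unit j is exact here because a zero column j in fc2.weight means the unit contributes nothing to fc2's output, whatever fc1 and the element-wise activation compute. Unlike a mask, the smaller layers really do need fewer multiply-adds.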


(Luyu Wang) #6

I think you are right: the code I have here prunes the weight gradients after backpropagation. Contributions from the pruned weights are still included, which is incorrect. Well, it was a hacky attempt; to do it better, I think we need to create a mask for each layer using register_hook.

Any thoughts?
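A minimal sketch of the hook-based idea, using torch.Tensor.register_hook on a layer's weight so that the pruning mask is applied to its gradient during every backward pass. The layer size and the 0.2 threshold are made up for illustration. Like the manual gradient zeroing in the first post, this keeps pruned weights at zero only with an optimizer whose update is zero for a zero gradient (plain SGD here).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 3)

with torch.no_grad():
    # Hypothetical threshold: weights with magnitude <= 0.2 are pruned.
    mask = (layer.weight.abs() > 0.2).float()
    layer.weight.mul_(mask)

# The hook receives the weight's gradient during backward;
# its return value replaces that gradient.
layer.weight.register_hook(lambda grad: grad * mask)

opt = torch.optim.SGD(layer.parameters(), lr=0.1)
x = torch.randn(8, 4)
for _ in range(3):
    opt.zero_grad()
    layer(x).pow(2).sum().backward()
    opt.step()
```

The hook stays registered on the parameter, so it applies on every iteration without any extra code in the training loop.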


(Edgar Medina) #7

I read the paper ( https://arxiv.org/pdf/1506.02626.pdf ), and the authors mention that the weights are masked as well. I think zero weights may be implemented in hardware without problems.

I’m not sure about it. I’m trying to implement this method as well.


(Issam H Laradji) #8

I am curious whether zeroing the gradients speeds up the operation: does PyTorch's implementation of .backward() support sparse weight updates?

If there is no speed-up, what stops you from zeroing the weights (instead of the gradients) after taking the optimization step? In that case the optimizer class doesn't matter, unless you are trying to avoid accumulating meta-information in the optimizer, such as momentum.
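A minimal sketch of this suggestion, with a made-up single linear layer, random data and a hypothetical 0.1 threshold: the pruning mask is re-applied to the weights right after every optimizer.step(), so even Adam's momentum cannot revive the pruned connections (although the optimizer still stores and updates state for them).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 2)

with torch.no_grad():
    # Hypothetical pruning threshold of 0.1 on the weight magnitudes.
    mask = (model.weight.abs() > 0.1).float()
    model.weight.mul_(mask)

opt = torch.optim.Adam(model.parameters(), lr=1e-2)
x, y = torch.randn(16, 8), torch.randn(16, 2)

for _ in range(5):
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()  # Adam may move pruned weights away from zero here...
    with torch.no_grad():
        model.weight.mul_(mask)  # ...so re-zero them right after the step
```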

You can prune the gradients by changing the forward pass as follows.

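If you have W = torch.randn(10, requires_grad=True) (a Variable in older PyTorch versions) and you would like the gradients of the first 5 coordinates only, you can do something like this in the forward pass:

import torch

X = torch.randn(20, 10)  # example inputs, made up for illustration
y = torch.randn(20)      # example targets
W = torch.randn(10, requires_grad=True)

# W[5:] is detached, so autograd treats it as a constant
loss = (torch.mv(X[:, :5], W[:5]) +
        torch.mv(X[:, 5:], W[5:].detach()) - y)**2
loss = loss.sum()

# Compute grad for W[:5] only
loss.backward()

which computes the gradient for W[:5] only (W.grad[5:] stays zero).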


(Edgar Medina) #9

Hello,

I think this is a better way to do it (mentioned at the beginning):

# zero-out all the gradients corresponding to the pruned connections
for l,p in enumerate(cnn.parameters()):
    pruned_inds = pruned_inds_by_layer[l]
    p.grad.data[pruned_inds] = 0.

Because I have the mask (the zero weights), I can update only the non-zero weights. However, my weights keep updating. I will try other options as well :)


(Issam H Laradji) #10

If the goal is to update only a few parameters, then the method in the top post will gain nothing in speed. That's because when you call backward(), all the gradients are computed anyway, so you lose the speed advantage of computing only a few gradients.

You can change the forward pass using the zero-weight mask so that only the gradients of the non-zero weights are computed.
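A minimal sketch of masking in the forward pass, with made-up shapes and a hypothetical 0.5 threshold. Because the weight enters the graph only as W * mask, the chain rule gives every pruned entry a gradient of exactly zero. With dense tensors, though, the full gradient is still computed and then multiplied by the mask, so this sketch by itself does not save computation.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
W = torch.randn(3, 4, requires_grad=True)
b = torch.zeros(3, requires_grad=True)
mask = (W.detach().abs() > 0.5).float()  # hypothetical pruning threshold
x = torch.randn(8, 4)

# d(loss)/dW = d(loss)/d(W * mask) * mask, so pruned entries get a zero gradient.
out = F.linear(x, W * mask, b)
loss = out.pow(2).sum()
loss.backward()

print(W.grad[mask == 0].abs().sum().item())  # 0.0
```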


(Issam H Laradji) #11

On another note, I believe pruning the gradients with your code will result in the wrong update. That's because gradients depend on each other: if you zero the gradients of some parameters in one layer, then the gradients in the earlier layers should change too (and might not necessarily be zero) because of the chain rule. Your code does not change those dependent gradients, so it is probably slow and will produce the wrong update.


(Edgar Medina) #12

I implemented a similar version (based on the first post) that updates the gradients correctly after setting some weights to zero. This update works because a gradient must be a Variable or None:

    for l,p in enumerate(net.parameters()):
        gradient_mask = (p!=0).data # Assuming some weights in 0
        p.grad = Variable(gradient_mask.float() * p.grad.data)

You are right about the first point, and I agree: the algorithm takes the same time as the unpruned version, because I still compute the gradients for the zero weights.

However, regarding this:

I believe pruning the gradients using your code will result in the wrong update.

My original model had 91.77% accuracy; after pruning, it had 71.78%. I retrained the model (without updating the zero weights, by setting their gradients to zero) and got 92.18% after 25 epochs, using SGD with lr=0.05 and dividing the learning rate by 10 at epochs 3, 10 and 16. I'm not sure whether this counts as fine-tuning, but it improved on the previous result. :D

Moreover, could you explain the example you gave a little more? I get lost in some parts when I try to apply it to my model.

Thank you in advance.
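The retraining schedule described in this post (SGD, lr=0.05, learning rate divided by 10 at epochs 3, 10 and 16, pruned weights kept at zero by zeroing their gradients) could be sketched roughly as follows with torch.optim.lr_scheduler.MultiStepLR. The model, data and 0.1 pruning threshold are hypothetical stand-ins, and each "epoch" here is a single step on one random batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 2)  # hypothetical stand-in for the pruned network
with torch.no_grad():
    model.weight[model.weight.abs() < 0.1] = 0.0  # hypothetical pruning threshold
weight_mask = (model.weight != 0).float()

opt = torch.optim.SGD(model.parameters(), lr=0.05)
# Divide the learning rate by 10 at epochs 3, 10 and 16.
sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[3, 10, 16], gamma=0.1)

x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))
lrs = []
for epoch in range(25):
    lrs.append(opt.param_groups[0]["lr"])
    opt.zero_grad()
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    model.weight.grad.mul_(weight_mask)  # pruned weights get no update
    opt.step()
    sched.step()
```

Plain SGD without momentum or weight decay is what makes gradient zeroing sufficient here: a zero gradient gives a zero update.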


(Luyu Wang) #13

So we have an implementation of weight pruning in this repo. The idea is to create a wrapper around the linear or conv layer and apply the mask in the forward pass. Since multiplying by the mask is a differentiable operation (essentially multiplying by a constant), PyTorch's autograd takes care of the backward pass automatically.

Still, you should not expect any acceleration: the goal of this implementation is to study the properties of pruned networks.
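In current PyTorch releases the same mask-in-the-forward-pass mechanism is available in torch.nn.utils.prune: pruning a module's weight replaces it with a weight_orig parameter and a weight_mask buffer, and a forward pre-hook recomputes weight = weight_orig * weight_mask before each forward pass. A minimal sketch using custom_from_mask with a magnitude threshold, in the spirit of Han et al.; the conv layer shape and the 0.05 threshold are made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3)

# Magnitude-based mask: keep weights above a hypothetical threshold.
mask = (conv.weight.detach().abs() > 0.05).float()
prune.custom_from_mask(conv, name="weight", mask=mask)

# The optimizer now trains conv.weight_orig (and conv.bias); autograd
# masks the gradient of weight_orig because weight = weight_orig * mask.
opt = torch.optim.SGD(conv.parameters(), lr=0.1)
x = torch.randn(2, 3, 16, 16)
for _ in range(3):
    opt.zero_grad()
    conv(x).pow(2).mean().backward()
    opt.step()

# Optionally make the pruning permanent (drops weight_orig and weight_mask).
prune.remove(conv, "weight")
```

As in the repo described above, the pruned weights are still stored densely, so this does not by itself speed anything up.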


#14

Hi, is your code final by now?
And why won’t it be faster?

I have been struggling with a similar question recently. I set the weights below the threshold to zero and also set their gradients to zero. But based on your discussion, should I create a mask for the loss instead?

So right now my gradients are incorrect and the training is super slow. Any thoughts on that?
Thanks a lot.
My code is:


(Edgar Medina) #15

Unfortunately, I had the same running-time problem. I did not know how to solve it, but it worked :)