Pruning in PyTorch

I understand that PyTorch has a pruning option. All it does is set some of the weights to 0. My doubt is: although they are zero, those edges still exist, right? Won’t that affect the backpropagation algorithm?

I believe it depends on the implementation. If you are referring to the functions in torch.nn.utils.prune, I would expect them to behave as in the truly “sparse” setting where the weights are removed, because a mask is applied during the computation rather than the weights simply being set to zero. The masked entries then receive zero gradient during backpropagation.

Consider the following toy example:

>>> import torch
>>> # Case 1: weights that are merely zero still receive a gradient.
>>> a = torch.zeros(10, requires_grad=True)
>>> a.sum().backward()
>>> a.grad
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
>>> # Case 2: weights multiplied by a random 0/1 mask.
>>> a2 = torch.zeros(10, requires_grad=True)
>>> mask = torch.round(torch.rand(10, requires_grad=False))
>>> mask
tensor([1., 0., 0., 0., 1., 1., 0., 1., 1., 0.])
>>> (a2*mask).sum().backward()
>>> # The gradient is zero wherever the mask is zero.
>>> a2.grad
tensor([1., 0., 0., 0., 1., 1., 0., 1., 1., 0.])
>>>
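
The same check can be run against torch.nn.utils.prune itself. The following is only a minimal sketch: the layer sizes, the seed, and the 50% pruning amount are arbitrary choices for illustration, not anything from the question. It prunes an nn.Linear layer with prune.l1_unstructured and then inspects the gradient of the underlying parameter.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)  # arbitrary seed, only for repeatability

# Small layer chosen for illustration: 3 x 4 = 12 weights.
layer = nn.Linear(4, 3)

# Prune half of the weights, picking those with the smallest magnitude.
prune.l1_unstructured(layer, name="weight", amount=0.5)

# After pruning, the trainable parameter is `weight_orig` and the
# 0/1 mask is stored as the buffer `weight_mask`.
param_names = [n for n, _ in layer.named_parameters()]
buffer_names = [n for n, _ in layer.named_buffers()]

# Forward and backward pass on random input.
out = layer(torch.randn(8, 4))
out.sum().backward()

mask = layer.weight_mask
grad = layer.weight_orig.grad

# Check that every masked position has a gradient of exactly zero.
masked_grad_is_zero = bool((grad[mask == 0] == 0).all())
print(param_names, buffer_names, masked_grad_is_zero)
```

As far as I can tell, the pruned layer still stores dense tensors, so the "edges" exist in memory and in the matrix multiply. They contribute nothing to the output, though, and no gradient reaches the masked entries of weight_orig.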