How to count the number of zero weights in a PyTorch model

Hi, after training a PyTorch model, how do you count the total number of zero weights in the model?


How about this?

import torch

def countZeroWeights(model):
    zeros = 0
    for param in model.parameters():
        if param is not None:
            # count elements equal to zero; .data[0] pulls out the Python number
            zeros += torch.sum((param == 0).int()).data[0]
    return zeros

@jpeg729 thanks. It works.

You could also go for:

zeros += param.numel() - param.nonzero().size(0)

I am not sure which one is faster, but I find this way a bit clearer.
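
Wrapped into a full helper, it would look roughly like this (just a sketch with an arbitrary function name, not benchmarked):

import torch

def count_zero_weights_via_nonzero(model):
    # total number of elements minus the number of non-zero elements, per parameter
    zeros = 0
    for param in model.parameters():
        zeros += param.numel() - param.nonzero().size(0)
    return zeros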


@theevann, @jpeg729 thanks for the help, I found that they both work. I have one more question regarding this: if I want to remove the zeros and re-save the model in a 2-bit format instead of the default float32 form of a PyTorch model, how do I do that?

You could store the positions of all the non-zeros:

weights.nonzero()

You can then use scatter to fluff the model back up again after loading. (Edit: you might need to .view(-1) the tensor before running .nonzero(), and then reshape it after loading/scattering.)
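
A rough sketch of that idea, with made-up helper names, and covering only the sparsity part of your question (not the 2-bit quantization):

import torch

def sparse_save(model, path):
    # keep only each parameter's non-zero values and their flat indices
    packed = {}
    for name, param in model.named_parameters():
        flat = param.detach().cpu().view(-1)   # flatten before calling .nonzero()
        idx = flat.nonzero().view(-1)          # flat positions of the non-zeros
        packed[name] = {
            "shape": param.shape,
            "indices": idx,
            "values": flat[idx],
        }
    torch.save(packed, path)

def sparse_load(model, path):
    packed = torch.load(path)
    with torch.no_grad():
        for name, param in model.named_parameters():
            entry = packed[name]
            flat = torch.zeros(param.numel(), dtype=entry["values"].dtype)
            # scatter the stored values back into a dense, flat tensor
            flat.scatter_(0, entry["indices"], entry["values"])
            # reshape after scattering and copy back into the model
            param.copy_(flat.view(entry["shape"]))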

@jpeg729 I tested both and it seems that @theevann’s method is faster (by a factor of 3). Also, I would recommend using .item() instead of .data[0].
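
For anyone landing here later, the original helper with that change applied would look something like this (a sketch against the more recent PyTorch API):

import torch

def count_zero_weights(model):
    zeros = 0
    for param in model.parameters():
        # torch.sum returns a 0-dim tensor; .item() extracts the Python number
        zeros += torch.sum(param == 0).item()
    return zeros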