Hi, after training a PyTorch model, how do you count the total number of zero weights in the model?

How about this?

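```
import torch

def countZeroWeights(model):
    zeros = 0
    for param in model.parameters():
        if param is not None:
            # (param == 0) is a boolean mask; summing it counts the exact zeros.
            # .item() converts the 0-dim result tensor to a Python int.
            zeros += torch.sum(param == 0).item()
    return zeros
```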
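If you also want to see where the zeros are, a per-parameter breakdown can help. This is a minimal sketch, assuming a PyTorch release recent enough to have `torch.count_nonzero`; `zero_report` is a hypothetical helper name (not a PyTorch API), and the toy `nn.Linear` values are made up for illustration.

```python
import torch
import torch.nn as nn


def zero_report(model: nn.Module) -> dict:
    """Hypothetical helper: {parameter name: number of exact zeros}, plus a 'total' key."""
    report = {}
    total = 0
    for name, param in model.named_parameters():
        # count_nonzero counts entries != 0; everything else is an exact zero
        zeros = param.numel() - int(torch.count_nonzero(param))
        report[name] = zeros
        total += zeros
    report["total"] = total
    return report


# Usage: a tiny layer with known values so the counts are deterministic
model = nn.Linear(4, 3)
with torch.no_grad():
    model.weight.fill_(0.5)
    model.weight[0].zero_()   # zero the first row: 4 exact zeros
    model.bias.fill_(1.0)     # no zeros in the bias
print(zero_report(model))     # {'weight': 4, 'bias': 0, 'total': 4}
```

Using `named_parameters()` instead of `parameters()` is what makes it possible to attribute the zeros to individual layers.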

You could also use:

```
# numel() is the total element count; nonzero() returns one row per non-zero element
zeros += param.numel() - param.nonzero().size(0)
```

I am not sure which one is faster, but I find this way a bit clearer.
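To check that the two approaches agree and get a rough feel for their speed, here is a quick sketch using the standard-library `timeit` module. The helper names and the random 512×512 test tensor are assumptions for illustration, and timings will vary with hardware and PyTorch version. One structural difference: `.nonzero()` builds an index tensor of shape (number of non-zeros, number of dims), while the `== 0` version only builds a boolean mask the size of the parameter.

```python
import functools
import timeit

import torch


def zeros_by_eq(t: torch.Tensor) -> int:
    # Approach 1: boolean mask of exact zeros, then sum it
    return int((t == 0).sum())


def zeros_by_nonzero(t: torch.Tensor) -> int:
    # Approach 2: total element count minus the number of non-zero positions
    return t.numel() - t.nonzero().size(0)


torch.manual_seed(0)
w = torch.randn(512, 512)
w[w.abs() < 0.5] = 0.0  # force a sizeable fraction of exact zeros

# Both methods must report the same count
assert zeros_by_eq(w) == zeros_by_nonzero(w)

for fn in (zeros_by_eq, zeros_by_nonzero):
    secs = timeit.timeit(functools.partial(fn, w), number=100)
    print(f"{fn.__name__}: {secs:.4f} s for 100 calls")
```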

@theevann, @jpeg729 thanks for the help, I found that they both work. I have one more question about this. If I want to remove the zeros and re-save the model, or store it with 2-bit weights instead of PyTorch's default 32-bit floats (float32), how do I do that?

You could store the positions of all the non-zeros (along with their values):

```
weights.nonzero()
```
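PyTorch also has a built-in sparse COO format that stores exactly this pair (positions plus values). A minimal sketch, swapping in `Tensor.to_sparse()` instead of calling `.nonzero()` by hand; the small example tensor is made up. Note that each non-zero costs one int64 index per dimension on top of its value, so this only saves space when most entries are zero.

```python
import io

import torch

w = torch.tensor([[0.0, 1.5, 0.0, 0.0],
                  [2.0, 0.0, 0.0, -3.0]])

# COO sparse tensor: indices has shape (ndim, nnz), values has shape (nnz,)
sp = w.to_sparse().coalesce()
print(sp.indices())  # positions of the non-zeros
print(sp.values())   # the non-zero values themselves

# Save the two dense pieces plus the original shape
buf = io.BytesIO()
torch.save({"indices": sp.indices(), "values": sp.values(), "shape": list(w.shape)}, buf)

# Load and rebuild the dense tensor
buf.seek(0)
state = torch.load(buf)
restored = torch.sparse_coo_tensor(state["indices"], state["values"], state["shape"]).to_dense()
assert torch.equal(restored, w)
```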

You can then use `scatter` to put the stored values back into a zero tensor and fluff the model back up again after loading. (Edit: you might need to `.view(-1)` the tensor before running `.nonzero()`, and then reshape it after loading/scattering.)
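Putting the flatten, `.nonzero()`, and `scatter` steps together, here is a minimal sketch that compresses every tensor in a `state_dict` to (indices, values, shape) and scatters it back after loading. `compress` and `decompress` are hypothetical helper names; `reshape(-1)` is used in place of `.view(-1)` because it also works on non-contiguous tensors. The indices are int64, so this only shrinks the file when most weights are zero, and it does not by itself give 2-bit storage.

```python
import io
import math

import torch
import torch.nn as nn


def compress(weight: torch.Tensor):
    """Hypothetical helper: keep only the non-zero entries plus the shape."""
    flat = weight.reshape(-1)               # flatten (like .view(-1), also OK if non-contiguous)
    idx = flat.nonzero().squeeze(1)         # 1-D int64 positions of the non-zeros
    return idx, flat[idx].clone(), tuple(weight.shape)


def decompress(idx: torch.Tensor, values: torch.Tensor, shape: tuple) -> torch.Tensor:
    """Hypothetical helper: scatter values into a zero tensor, then restore the shape."""
    flat = torch.zeros(math.prod(shape), dtype=values.dtype, device=values.device)
    flat.scatter_(0, idx, values)           # flat[idx[i]] = values[i]
    return flat.view(shape)


# Usage: zero part of a small layer, save the compressed form, reload it
model = nn.Linear(4, 3)
with torch.no_grad():
    model.weight[:, :2] = 0.0               # zero the first two input columns

packed = {name: compress(t) for name, t in model.state_dict().items()}
buf = io.BytesIO()
torch.save(packed, buf)

buf.seek(0)
loaded = torch.load(buf)
restored_model = nn.Linear(4, 3)
restored_model.load_state_dict({name: decompress(*p) for name, p in loaded.items()})
print(torch.equal(restored_model.weight, model.weight))  # True
```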