Hi, after training a PyTorch model, how do you count the total number of zero weights in the model?
How about this?
    import torch

    def countZeroWeights(model):
        zeros = 0
        for param in model.parameters():
            if param is not None:
                # count the elements of this tensor that are exactly zero
                zeros += torch.sum((param == 0).int()).item()
        return zeros
@jpeg729 Thanks, it works.
You could also go for:
zeros += param.numel() - param.nonzero().size(0)
I am not sure which one is faster, but I find this way a bit clearer.
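To check that the two approaches agree, here is a small self-contained sketch that runs both on an nn.Linear layer with a known number of zeros. The function names are hypothetical, and the third variant uses torch.count_nonzero, which is an extra alternative not mentioned in the thread and assumes PyTorch 1.7 or newer.

```python
import torch
import torch.nn as nn


def count_zeros_eq(model: nn.Module) -> int:
    # Approach 1: compare every element with 0 and sum the matches.
    return sum(int((p == 0).sum().item()) for p in model.parameters())


def count_zeros_nonzero(model: nn.Module) -> int:
    # Approach 2: total elements minus the number of non-zero elements.
    # p.nonzero() returns one row of indices per non-zero element.
    return sum(p.numel() - p.nonzero().size(0) for p in model.parameters())


def count_zeros_count_nonzero(model: nn.Module) -> int:
    # Approach 3 (assumes PyTorch >= 1.7): torch.count_nonzero returns the
    # count directly instead of building an index tensor like nonzero().
    return sum(p.numel() - int(torch.count_nonzero(p).item()) for p in model.parameters())


# Small example model with a known number of zeros.
model = nn.Linear(4, 3)  # weight: 3x4 = 12 elements, bias: 3 elements
with torch.no_grad():
    model.weight.fill_(1.0)  # start from all non-zero values
    model.bias.fill_(1.0)
    model.weight[0].zero_()  # 4 zeros in the first row of weights
    model.bias.zero_()       # 3 more zeros in the bias

print(count_zeros_eq(model), count_zeros_nonzero(model), count_zeros_count_nonzero(model))
# → 7 7 7
```

For timing on a real model, wrapping each call in timeit on your own hardware is the most reliable way to see which is faster.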
@theevann, @jpeg729 Thanks for the help, I found that they both work. I have one more question about this. If I want to remove the zeros and re-save the model, or store it with 2-bit weights instead of PyTorch's default float32, how do I do that?
You could store the positions of all the non-zeros:
(You can then use scatter to fluff the model back up again after loading. Edit: you might need to .view(-1) the tensor before running .nonzero(), and then reshape it after loading/scattering.)
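A minimal sketch of the "store the non-zero positions, scatter them back" idea, assuming CPU tensors and a plain state_dict. The helper names compress_state_dict and decompress_state_dict are hypothetical. Each tensor is flattened first, so nonzero() gives 1-D positions, and the original shape is kept so the tensor can be reshaped after scattering.

```python
import torch
import torch.nn as nn


def compress_state_dict(model: nn.Module) -> dict:
    # Keep only the flat positions and values of the non-zeros, plus shape/dtype.
    packed = {}
    for name, tensor in model.state_dict().items():
        flat = tensor.reshape(-1)           # same as .view(-1) for contiguous tensors
        idx = flat.nonzero().squeeze(1)     # int64 positions, shape (nnz,)
        packed[name] = {
            "shape": tuple(tensor.shape),
            "dtype": tensor.dtype,
            "indices": idx,
            "values": flat[idx],            # advanced indexing returns a copy
        }
    return packed


def decompress_state_dict(packed: dict) -> dict:
    # Scatter the stored values into zero-filled flat tensors, then reshape.
    state = {}
    for name, entry in packed.items():
        flat = torch.zeros(entry["shape"], dtype=entry["dtype"]).view(-1)
        flat.scatter_(0, entry["indices"], entry["values"])
        state[name] = flat.view(entry["shape"])
    return state


model = nn.Linear(4, 3)
with torch.no_grad():
    model.weight[0].zero_()  # make the first row of weights zero

packed = compress_state_dict(model)
# torch.save(packed, "sparse_model.pt") would write this dict to disk.

restored = nn.Linear(4, 3)
restored.load_state_dict(decompress_state_dict(packed))
```

Note that the positions are stored as int64 (8 bytes each) alongside float32 values (4 bytes each), so this only saves space once most of a tensor is zero. It also does not by itself reduce the bit width of the remaining weights.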