Hi, after training a PyTorch model, how do you count the total number of zero weights in the model?
How about this?
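import torch

def countZeroWeights(model):
    zeros = 0
    for param in model.parameters():
        if param is not None:
            # (param == 0) is a boolean mask; summing it counts the zero entries.
            # .item() turns the 0-dim result into a Python int (.data[0] fails on recent PyTorch).
            zeros += torch.sum((param == 0).int()).item()
    return zeros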
You could also use this line inside the loop instead:
zeros += param.numel() - param.nonzero().size(0)
I am not sure which one is faster, but I find this way a bit clearer.
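Putting that one-liner into a complete function gives something like the minimal sketch below. The snake_case name count_zero_weights and the small nn.Linear test model are illustrative choices, not from the posts above. Like the first version, it counts only entries that are exactly zero; weights that are merely tiny after training are not counted.

```python
import torch
import torch.nn as nn


def count_zero_weights(model: nn.Module) -> int:
    """Count exactly-zero entries across all parameters of `model`."""
    zeros = 0
    for param in model.parameters():
        # nonzero() returns one row per non-zero element, so size(0) is the
        # number of non-zeros; subtracting it from numel() gives the zeros.
        zeros += param.numel() - param.nonzero().size(0)
    return zeros


# Usage: a small layer with some weights forced to zero.
model = nn.Linear(4, 3)  # 12 weights + 3 biases = 15 parameters
with torch.no_grad():
    model.weight.fill_(1.0)
    model.bias.fill_(1.0)
    model.weight[0].zero_()  # zero the first row of weights: 4 zeros

print(count_zero_weights(model))  # 4
```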
@theevann, @jpeg729 Thanks for the help, I found that they both work. I have one more question about this: if I want to remove the zeros and re-save the model, storing the weights as 2-bit values instead of PyTorch's default float32, how do I do that?
You could store the positions of all the non-zeros (along with the non-zero values themselves, so they can be restored):
weights.nonzero()
(You can then use scatter to fluff the model back up again after loading. Edit: you might need to .view(-1) the tensor before running .nonzero(), and then reshape it after loading/scattering.)