# Sparsifying the elements of tensors

Hi PyTorch community,
I asked this once before but nobody answered, so I'm asking again. I would appreciate any help you can provide.
I have a dictionary that contains several tensors (the Fisher matrix of a CNN), like this:
`fisher = {'conv1.weight': <tensor of shape (6, 3, 5, 5)>, 'conv1.bias': …}`
Now I want to keep the largest 50 percent of the elements in the Fisher matrix of the entire CNN (i.e., across the whole dictionary) and set the rest to 0, so that I get a sparse version of it.
I would appreciate it if anyone could help me with that.

Wouldn’t it work to index the tensor(s) with a mask that selects the largest elements and zero out the others? Or are you processing the entire parameter set somehow first?
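A minimal sketch of the mask-based approach for a single tensor, assuming "largest" means largest magnitude (for a Fisher diagonal, whose entries are non-negative, this is the same as largest value). The function name `keep_largest_per_tensor` and its `fraction` parameter are hypothetical; `torch.topk` and `Tensor.masked_fill` are standard PyTorch calls.

```python
import torch

def keep_largest_per_tensor(t: torch.Tensor, fraction: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: keep the `fraction` largest-magnitude entries of `t`, zero the rest."""
    # Flatten so topk can rank every entry of the tensor at once
    flat = t.flatten()
    k = int(flat.numel() * fraction)  # number of entries to keep
    mask = torch.zeros(flat.numel(), dtype=torch.bool, device=t.device)
    if k > 0:
        # Mark the indices of the k entries with the largest magnitude
        mask[flat.abs().topk(k).indices] = True
    # Zero out everything outside the mask and restore the original shape
    return flat.masked_fill(~mask, 0).reshape_as(t)
```

Using the indices returned by `topk` (rather than a value threshold) keeps exactly `k` entries even when values tie. Note that applying this to each dictionary entry separately keeps 50% of every tensor, which is not the same as keeping 50% of the entire dictionary.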


Well, the use case is that I want to add a penalty to my loss function based on those tensors, which are the diagonal of the Fisher information matrix of another model, so that I can keep the two models’ parameters close to each other. Now I want those tensors to be sparse,
and I want to sparsify the entire parameter set at once.
So if I have Fisher information like this:
`fisher = {'conv1.weight': <tensor of shape (6, 3, 5, 5)>, …, 'fc1.weight': <tensor of shape (200, 400)>, …}`
I want to keep the largest 50 percent of the elements of the entire `fisher` dictionary.
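For context, the penalty described here resembles the quadratic term used in Elastic Weight Consolidation (EWC). The sketch below assumes that form, i.e. the sum over all parameters of `lam / 2 * F_i * (theta_i - theta_old_i) ** 2`, and is not necessarily the poster's exact loss. The names `fisher_penalty`, `old_params` (a dict holding the other model's parameters) and `lam` are hypothetical, and the dictionary keys are assumed to match `model.named_parameters()` names such as `conv1.weight`.

```python
import torch
import torch.nn as nn

def fisher_penalty(model: nn.Module, fisher: dict, old_params: dict, lam: float = 1.0):
    """Hypothetical EWC-style penalty: sum_i lam/2 * F_i * (theta_i - theta_old_i)^2.

    Returns a tensor as long as at least one parameter name appears in `fisher`.
    """
    penalty = 0.0
    for name, param in model.named_parameters():
        if name in fisher:
            # Fisher-weighted squared distance, summed over every entry of this parameter
            penalty = penalty + (fisher[name] * (param - old_params[name]) ** 2).sum()
    return lam / 2 * penalty
```

Because the penalty is weighted entry by entry, any Fisher entry set to 0 by sparsification simply stops constraining the corresponding parameter.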

Should the sparsity be “learned” by the model so that it keeps the “50% largest elements”? If so, I unfortunately don’t have enough experience in this area and wouldn’t know how to achieve this.

No, we don’t need the model here. I was trying to explain the background and where that dictionary came from. Now we have the dictionary, which has some keys, and each key has a tensor assigned to it.
I want a function that takes that dictionary as input, keeps the largest 50 percent of the elements across the whole dictionary, sets the rest to 0, and returns the sparsified dictionary.
I don’t know whether I explained what I want properly. Sorry for my bad explanation.
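A minimal sketch of such a function, assuming all tensors share the same device and floating-point dtype (required by `torch.cat`) and that "largest" means largest magnitude. The name `sparsify_dict` is hypothetical. The key step is ranking all entries globally by concatenating the flattened tensors, then slicing the resulting mask back into per-tensor pieces.

```python
import torch

def sparsify_dict(fisher: dict, fraction: float = 0.5) -> dict:
    """Hypothetical helper: keep the `fraction` largest entries across ALL tensors, zero the rest."""
    # Concatenate every tensor into one flat vector so the ranking is global
    flat = torch.cat([t.flatten() for t in fisher.values()])
    k = int(flat.numel() * fraction)  # number of entries to keep overall
    mask = torch.zeros(flat.numel(), dtype=torch.bool, device=flat.device)
    if k > 0:
        mask[flat.abs().topk(k).indices] = True
    # Split the global mask back into per-tensor pieces (same order as the concatenation)
    out = {}
    offset = 0
    for name, t in fisher.items():
        n = t.numel()
        piece = mask[offset:offset + n].view_as(t)
        out[name] = t.masked_fill(~piece, 0)  # returns a new tensor; input is untouched
        offset += n
    return out
```

For example, `sparse_fisher = sparsify_dict(fisher, 0.5)` returns a new dictionary with the same keys and shapes. The concatenation makes one temporary copy of the whole Fisher diagonal, which is usually acceptable for a small CNN.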

In that case masking should work:

``````python
import torchvision.models as models

model = models.resnet18()
sd = model.state_dict()

for key in sd:
    print(key)
    print('abs.sum before {}'.format(sd[key].abs().sum()))