Hi Mayank!
Use topk() to get the indices of the n - k smallest-magnitude values in your
weight and then use scatter() to zero them out:
>>> import torch
>>> print (torch.__version__)
2.3.1
>>>
>>> _ = torch.manual_seed (2024)
>>>
>>> m = 3
>>> n = 5
>>> k = 2
>>>
>>> t = torch.randn (3, 5)
>>>
>>> t
tensor([[-0.0404, 1.7260, -0.8140, 1.3722, 0.5060],
[-0.4823, -0.7853, 0.6681, -0.4439, 0.1888],
[ 0.5986, 0.6458, 0.6306, -1.4668, -0.6798]])
>>>
>>> ind = t.abs().topk (n - k, dim = 1, largest = False).indices
>>> t_topk = t.scatter(1, ind, torch.zeros (ind.size()))
>>>
>>> t_topk
tensor([[ 0.0000, 1.7260, 0.0000, 1.3722, 0.0000],
[ 0.0000, -0.7853, 0.6681, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, -1.4668, -0.6798]])
(Note, I believe you want to zero out the smaller values because removing
them from the weight (and changing the shape of the tensor) will likely
break the connections between the various layers in your model.)
Rather than apply this logic to your layers after every optimization step
(which would likely un-zero your zeroed values), you would probably want
to package this logic as a PyTorch parametrization and register it for each
layer that you want to treat this way.
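Here is a minimal sketch of what that parametrization could look like (the
class name KeepTopK and the choice of k are mine, just for illustration):

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

_ = torch.manual_seed(2024)

class KeepTopK(nn.Module):
    """Parametrization that keeps only the k largest-magnitude
    entries in each row of the weight, zeroing the rest."""
    def __init__(self, k):
        super().__init__()
        self.k = k

    def forward(self, w):
        # indices of the n - k smallest-magnitude entries per row
        n = w.size(1)
        ind = w.abs().topk(n - self.k, dim=1, largest=False).indices
        # zero them out without changing the tensor's shape
        return w.scatter(1, ind, torch.zeros(ind.size(), dtype=w.dtype))

lin = nn.Linear(5, 3)
parametrize.register_parametrization(lin, "weight", KeepTopK(k=2))
```

After registering, every access to lin.weight (including during forward())
goes through KeepTopK, so the pruning is reapplied automatically after each
optimizer step, while the underlying unpruned parameter keeps training.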
Best.
K. Frank