How to apply a threshold to model weights while training?

Let's say I have a model with a bunch of linear layers (nn.Linear). Say a layer 'fc1' has a weight parameter of shape (m, n). For each row, I want to keep only the top-k values (by magnitude).

What is the correct way to implement this while training the model? I looked at torch.nn.utils.prune, which looks similar, but I am not sure how to use it, or whether there is another way to do it correctly.

Hi Mayank!

Use topk() to get the indices of the n - k smallest-magnitude values in
each row of your weight and then use scatter() to zero them out:

>>> import torch
>>> print (torch.__version__)
2.3.1
>>>
>>> _ = torch.manual_seed (2024)
>>>
>>> m = 3
>>> n = 5
>>> k = 2
>>>
>>> t = torch.randn (m, n)
>>>
>>> t
tensor([[-0.0404,  1.7260, -0.8140,  1.3722,  0.5060],
        [-0.4823, -0.7853,  0.6681, -0.4439,  0.1888],
        [ 0.5986,  0.6458,  0.6306, -1.4668, -0.6798]])
>>>
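>>> # zero out the n - k smallest-magnitude entries in each row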
>>> ind = t.abs().topk (n - k, dim = 1, largest = False).indices
>>> t_topk = t.scatter(1, ind, torch.zeros (ind.size()))
>>>
>>> t_topk
tensor([[ 0.0000,  1.7260,  0.0000,  1.3722,  0.0000],
        [ 0.0000, -0.7853,  0.6681,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000, -1.4668, -0.6798]])

(Note: I believe you want to zero out the smaller values, because actually
removing them from the weight (and thereby changing the shape of the
tensor) would likely break the connections between the various layers of
your model.)

Rather than applying this logic to your layers after every optimization
step (each step would likely un-zero your zeroed values, so you would have
to keep reapplying it), you would probably want to package this logic as a
pytorch parametrization (see
torch.nn.utils.parametrize.register_parametrization()) and register it for
each layer that you want to treat this way.
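Here is a minimal sketch of what such a parametrization might look like.
(The class name TopKRows is my own invention, and I am assuming that the
weight being parametrized is two-dimensional.)

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class TopKRows (nn.Module):
    # parametrization that keeps only the k largest-magnitude entries
    # in each row of a 2-d tensor, zeroing out the rest
    def __init__ (self, k):
        super().__init__()
        self.k = k

    def forward (self, w):
        # indices of the n - k smallest-magnitude entries in each row
        ind = w.abs().topk (w.size (1) - self.k, dim = 1, largest = False).indices
        return w.scatter (1, ind, torch.zeros (ind.size(), dtype = w.dtype, device = w.device))

fc1 = nn.Linear (5, 3)   # fc1.weight has shape (3, 5)
parametrize.register_parametrization (fc1, 'weight', TopKRows (2))
print (fc1.weight)   # each row now has only its two largest-magnitude entries

Once registered, fc1.weight is recomputed from the underlying
unconstrained parameter every time it is accessed, so the top-k structure
holds throughout training without any manual reapplication. (Gradients
flow back to the surviving entries; the index selection itself is not
differentiated through.)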

Best.

K. Frank