How to easily modify the ReLU layer in PyTorch?

I am new to PyTorch. I am trying to create a new activation layer, let's call it topk, that works as follows. It takes a vector x of size n as input (the result of multiplying the previous layer's output by a weight matrix and adding a bias) and a positive integer k, and outputs a vector topk(x) of size n whose elements are
(topk(x))_i = x_i if x_i is one of the top k elements of x, and 0 otherwise.

When calculating the gradient of topk(x), the top k elements of x should have gradient 1 and everything else gradient 0.
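For example, with n = 5, k = 2, and x = (3, 1, 4, 1, 5), I would expect topk(x) = (0, 0, 4, 0, 5), and the gradient of sum(topk(x)) with respect to x to be (0, 0, 1, 0, 1).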

How should I implement this? Can you please provide some sample code?

Hi Standshik!

You can do this using torch.topk() and .scatter_():

>>> torch.__version__
'1.7.1'
>>> _ = torch.manual_seed (2021)
>>> x = torch.randperm (10).float()
>>> x.requires_grad = True
>>> x_topk_val_ind = torch.topk (x, 5)
>>> x_topk_val_ind
torch.return_types.topk(
values=tensor([9., 8., 7., 6., 5.], grad_fn=<TopkBackward>),
indices=tensor([8, 6, 3, 5, 1]))
>>> topk = torch.zeros_like (x).scatter_ (0, x_topk_val_ind[1], x_topk_val_ind[0])
>>> topk
tensor([0., 5., 0., 7., 0., 6., 8., 0., 9., 0.], grad_fn=<ScatterBackward0>)
>>> topk.sum().backward()
>>> x.grad
tensor([0., 1., 0., 1., 0., 1., 1., 0., 1., 0.])
>>> x
tensor([4., 5., 3., 7., 2., 6., 8., 0., 9., 1.], requires_grad=True)
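
For example, you could wrap that recipe in a function along these lines (just a sketch; the name topk_activation and the choice to work along the last dimension are mine, so that it also handles batched input):

import torch

def topk_activation(x, k):
    # keep the k largest entries along the last dimension and zero out the
    # rest; gradients flow only through the kept entries (1 there, 0 elsewhere)
    values, indices = torch.topk(x, k, dim=-1)
    return torch.zeros_like(x).scatter_(-1, indices, values)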

Best.

K. Frank

Hi Frank,
Thank you so much. That was really helpful. My goal is to use this as an activation when building a neural network. Suppose I define a simple network like the following, where I use the nn.ReLU() activation. How do I package the topk idea so that I can create ‘mypackage’, import it, and use mypackage.topk(k) instead of nn.ReLU()?

class Net(nn.Module):
    def __init__(self, input_size, hidden_size, out_size):
        super(Net, self).__init__()                    
        self.fc1 = nn.Linear(input_size, hidden_size)                              
        self.fc2 = nn.Linear(hidden_size, out_size)
        self.relu = nn.ReLU()

    def forward(self, x):                          
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

Hi Standshik!

If it were me, I wouldn't bother packaging it; I would just define topk
as a function (rather than a class). You could define it (either as a
function or a class) in a separate package and import it, but how to
do that is a Python question rather than one specific to PyTorch.

There is no need to instantiate a self.relu() replacement; just define
your topk function and call it in forward():

    def forward(self, x):                          
        out = self.fc1(x)
        # out = self.relu(out)
        out = my_topk_function (out, my_value_for_k)
        out = self.fc2(out)
        return out
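
If you do want something you can instantiate and use like nn.ReLU() (as in your Net above), a thin nn.Module wrapper around such a function should also work. This is only a sketch, and TopK is a placeholder name:

import torch
import torch.nn as nn

class TopK(nn.Module):
    def __init__(self, k):
        super().__init__()
        self.k = k

    def forward(self, x):
        # keep the k largest entries along the last dimension, zero out the rest
        values, indices = torch.topk(x, self.k, dim=-1)
        return torch.zeros_like(x).scatter_(-1, indices, values)

You could then write self.topk = TopK(k) in __init__() and out = self.topk(out) in forward(), just as you did with self.relu.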

Best.

K. Frank