Can I trust my own custom thresholding layer?

PyTorch beginner here. I wrote my first custom layer, which should zero out all but the k largest values (by absolute value) of a tensor, e.g. Thresher([1, 3, -2, 4, 2]) for k=2 gives [0, 3, 0, 4, 0].

The code seems to work (I checked it myself), but I think something is still wrong, because I have to call "loss.backward(retain_variables=True)", otherwise I get an error, and my loss increases rapidly.

Here is my short code:

import torch
import torch.autograd as AG

class Thresher(AG.Function):
    def __init__(self, n):
        super(Thresher, self).__init__()
        self.holds = n
    
    def forward(self, input):
        self.save_for_backward(input)
        output = input.clone()
        # sort the absolute values along the last dimension (ascending)
        th = torch.sort(torch.abs(input))[0]
        # the k-th largest magnitude, taken from the first row
        threshold = th[0, output.size()[1] - self.holds]
        # zero out everything below the threshold
        idx = torch.abs(input) < threshold
        output[idx] = 0
        return output
    
    def backward(self, grad_output):
        input = self.saved_tensors[0]
        # recompute the mask of the k largest magnitudes
        th = torch.sort(torch.abs(input))[0]
        threshold = th[0, input.size()[1] - self.holds]
        idx = torch.abs(input) >= threshold
        # pass a gradient of 1 through the kept positions
        grad_input = torch.zeros(input.size())
        grad_input[idx] = 1
        print(grad_input)
        return grad_input
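
To make the example from the top concrete, calling the layer on the toy tensor should reproduce it (a minimal sketch, assuming the old-style Function API used in this thread):

from torch.autograd import Variable

x = Variable(torch.Tensor([[1, 3, -2, 4, 2]]))
# keep the 2 largest magnitudes (3 and 4), zero out the rest
print(Thresher(2)(x))  # expected: [[0, 3, 0, 4, 0]]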

and my model:

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(100, 75)
        self.fc2 = nn.Linear(75, 100)
        self.th = Thresher(4)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.th(x)
        return x

If useful, I can provide a minimal working example.

The first thing I’d do is to use torch.autograd.gradcheck to verify that your gradients are right.
Also, you might want to check torch.topk, which is faster than torch.sort if you only want a subset of the elements of the tensor.
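
For reference, a minimal sketch of both suggestions, with a hypothetical toy input and the old-style Variable API that matches the rest of this thread (gradcheck uses finite differences, so the input should be double precision):

import torch
from torch.autograd import Variable, gradcheck

# hypothetical double-precision input; prints True only if your
# analytical backward matches the numerical gradients
inp = Variable(torch.randn(1, 10).double(), requires_grad=True)
print(gradcheck(Thresher(4), (inp,), eps=1e-6, atol=1e-4))

# torch.topk returns the k largest entries directly, no full sort needed
vals, _ = torch.topk(torch.abs(inp.data), 4)
threshold = vals[0, -1]  # the smallest of the 4 largest magnitudes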


I got it working 🙂 Thanks!

I think if you write this as a normal Module (instead of an autograd Function) you get the backward pass for free:

import torch
import torch.nn as nn

class Thresher(nn.Module):
    def __init__(self, n):
        super(Thresher, self).__init__()
        self.holds = n

    def forward(self, input):
        abs_input = torch.abs(input)
        th, _ = torch.sort(abs_input)
        threshold = th[0, input.size()[1] - self.holds]
        mask = (abs_input >= threshold).type_as(input)
        return input * mask

# or just as a normal function

def threshold(input, holds):
    abs_input = torch.abs(input)
    th, _ = torch.sort(abs_input)
    thresh = th[0, input.size()[1] - holds]
    mask = (abs_input >= thresh).type_as(input)
    return input * mask
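
As a quick sanity check that autograd produces the masked gradient automatically, with a hypothetical toy input (old-style Variable API, as elsewhere in this thread):

from torch.autograd import Variable

x = Variable(torch.Tensor([[1, 3, -2, 4, 2]]), requires_grad=True)
out = threshold(x, 2)  # keeps 3 and 4, zeros the rest
out.backward(torch.ones(1, 5))
print(x.grad)  # ones at the kept positions: [[0, 1, 0, 1, 0]]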

Thanks for your hint, but your function produces the wrong output. I guess there should be ones where values are backpropagated, not the values of the input, right?!