PyTorch beginner here. I wrote my first custom layer, which should zero out every value of a tensor except the k entries with the largest absolute value, e.g. Thresher([1, 3, -2, 4, 2]) for k=2 gives [0, 3, 0, 4, 0].
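To make the intended behaviour concrete, here is the same operation sketched with torch.topk on a plain tensor (just an illustration, not my layer):

import torch

t = torch.Tensor([1, 3, -2, 4, 2])
_, idx = torch.topk(torch.abs(t), 2)  # positions of the 2 largest magnitudes
out = torch.zeros(t.size()).index_copy_(0, idx, t.index_select(0, idx))
print(out)  # 0, 3, 0, 4, 0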
The code seems to work (I checked the output myself), but I think something is still wrong, because I have to call loss.backward(retain_variables=True), otherwise I get an error, and on top of that my loss increases rapidly during training.
Here is my short code:
import torch
import torch.autograd as AG

class Thresher(AG.Function):
    def __init__(self, n):
        super(Thresher, self).__init__()
        self.holds = n  # number of entries to keep

    def forward(self, input):
        self.save_for_backward(input)
        output = input.clone()
        # sort the magnitudes ascending; the n-th largest is the threshold
        th = torch.sort(torch.abs(input))[0]
        threshold = th[0, output.size()[1] - self.holds]
        # zero every entry whose magnitude is below the threshold
        idx = torch.abs(input) < threshold
        output[idx] = 0
        return output

    def backward(self, grad_output):
        input = self.saved_tensors[0]
        # recompute the threshold exactly as in forward
        th = torch.sort(torch.abs(input))[0]
        threshold = th[0, input.size()[1] - self.holds]
        idx = torch.abs(input) >= threshold
        # pass a gradient of 1 through the kept entries, 0 elsewhere
        grad_input = torch.zeros(input.size())
        grad_input[idx] = 1
        print(grad_input)  # debug output
        return grad_input
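For example, the forward pass checks out on the toy input from above (note the input has to be 2-D because of the th[0, ...] indexing):

import torch
from torch.autograd import Variable

x = Variable(torch.Tensor([[1, 3, -2, 4, 2]]))
print(Thresher(2)(x))  # 0 3 0 4 0, as expected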
and my model:
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(100, 75)
        self.fc2 = nn.Linear(75, 100)
        self.th = Thresher(4)  # keep the 4 largest activations

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.th(x)
        return x
For completeness, here is a mini working example that shows how I use the model.
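The random input, target, nn.MSELoss, and SGD below are stand-ins for my real setup; as described above, the backward call errors out unless I pass retain_variables=True:

import torch
import torch.nn as nn
from torch.autograd import Variable

net = Net()
criterion = nn.MSELoss()  # placeholder loss
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

x = Variable(torch.randn(1, 100))  # placeholder data
target = Variable(torch.randn(1, 100))

for step in range(10):
    optimizer.zero_grad()
    loss = criterion(net(x), target)
    loss.backward(retain_variables=True)  # fails without retain_variables=True
    optimizer.step()
    print(step, loss.data[0])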