This should work for what you want, all wrapped up as the activation function ‘non_sat_relu’. The forward pass is a standard ReLU; the backward pass leaks 10% of the gradient through where the input was negative. The leak is set to 0.1 here, but it could of course be whatever you want.
import torch

class NSReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Remember which inputs were negative so backward can leak there
        ctx.neg = x < 0
        return x.clamp(min=0.0)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        # Pass 10% of the gradient through where the input was negative
        grad_input[ctx.neg] *= 0.1
        return grad_input

def non_sat_relu(x):
    return NSReLU.apply(x)
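
As a quick sanity check of the leak (the tensor values here are just an example, not anything special):

x = torch.tensor([-2.0, -0.5, 1.0, 3.0], requires_grad=True)
y = non_sat_relu(x)
y.sum().backward()
print(y)       # tensor([0., 0., 1., 3.], grad_fn=<NSReLUBackward>)
print(x.grad)  # tensor([0.1000, 0.1000, 1.0000, 1.0000])

The forward output is exactly ReLU, but the negative entries still get a gradient of 0.1 instead of 0.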