My understanding is that for classification tasks there is the intuition that:
(1) relu activation functions encourage sparsity, which is good (for generalization?) but that
(2) a leaky relu solves the gradient saturation problem, which relu has, at the cost of sparsity.

Is it possible, in PyTorch, to write an activation function which on the forward pass behaves like relu but which has a small positive derivative for x < 0? In other words, its forward pass is relu but its backward pass is like that of a leaky relu with some specified leak rate. Of course this breaks the FTC (fundamental theorem of calculus), but that (probably?) doesn't matter, and you would have sparsity + non-saturating gradients if it worked.

I'm relatively new to deep learning, so sorry if this has already been tried. I also don't have the experience to know if tuning activation functions leads to significant performance gains.

For ad-hoc experimentation with backward not quite equal to forward, the pattern

y = x_backward + (x_forward - x_backward).detach()

works quite well. It gets you x_forward in the forward pass, but the derivative will act as if you had x_backward. Stay clear of NaN (and of infinity - infinity, which is NaN). It's probably a tad more expensive than a custom autograd.Function.

Hi Tom,
That's a great idea! Just what I was looking for.
Could you please let me know how to use it? I am a little confused about x_backward: should it be used after loss.backward()?
Please help!

x_backward would correspond to the activation in your forward method, which should be used during the backward pass. In this example, you could define x_backward as the LeakyReLU output and x_forward as the ReLU output.
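
As a minimal sketch of that suggestion (the function name relu_forward_leaky_backward and the leak value are just illustrative choices, not something from the posts above), the detach pattern can be combined with the built-in relu and leaky_relu like this:

```python
import torch
import torch.nn.functional as F

def relu_forward_leaky_backward(x, leak=0.1):
    # Hypothetical helper: relu values in the forward pass,
    # leaky relu gradients in the backward pass.
    x_forward = F.relu(x)
    x_backward = F.leaky_relu(x, negative_slope=leak)
    # The detached difference contributes to the value but not to the gradient.
    return x_backward + (x_forward - x_backward).detach()

x = torch.tensor([-2.0, -0.5, 1.0, 3.0], requires_grad=True)
y = relu_forward_leaky_backward(x)
y.sum().backward()

print(y)       # values agree with F.relu(x)
print(x.grad)  # gradient is `leak` where x < 0 and 1 where x > 0
```

Everything happens inside the forward computation; nothing extra needs to be done after loss.backward().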

This should work for what you want, all wrapped up as the activation function "non_sat_relu". The backward leak is set to 0.1, but it could of course be whatever you want.
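
The code from that post is not reproduced in this thread, so the following is only a sketch of what such a function might look like, assuming a custom torch.autograd.Function. The class name NonSatReLU and the leak argument are assumptions; only the name non_sat_relu and the default of 0.1 come from the post.

```python
import torch

class NonSatReLU(torch.autograd.Function):
    """Sketch: relu in the forward pass, leaky-relu-style gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, leak):
        ctx.save_for_backward(x)
        ctx.leak = leak  # plain Python float, stored on the context
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Slope 1 where x > 0, slope `leak` elsewhere.
        slope = torch.where(x > 0, torch.ones_like(x), torch.full_like(x, ctx.leak))
        # One return value per forward input; `leak` gets no gradient.
        return grad_output * slope, None

def non_sat_relu(x, leak=0.1):
    return NonSatReLU.apply(x, leak)

x = torch.tensor([-1.0, 2.0], requires_grad=True)
out = non_sat_relu(x)
out.sum().backward()
print(out, x.grad)
```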

@nthn_clmnt Just curious, when using @staticmethod, why are we passing self?
Plus we can also implement this via inheriting nn.Module as well, right?
Thanks!

I guess the secret first argument ctx of forward and backward is not a class instance but something called a context. I think you forced me to learn something about Python, so thank you. So although my code runs fine, calling that argument self suggests the wrong thing.

I haven't seen any documentation that you can re-implement backward at the level of an nn.Module rather than a torch.autograd.Function.
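
That said, an nn.Module can still wrap the behaviour without redefining backward itself, by doing the forward/backward mismatch inside its forward method. A rough sketch, where the class name NonSatReLUModule and the layer sizes are made up for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NonSatReLUModule(nn.Module):
    """Hypothetical wrapper: relu output, leaky-relu gradient, usable like nn.ReLU."""

    def __init__(self, leak=0.1):
        super().__init__()
        self.leak = leak

    def forward(self, x):
        x_forward = F.relu(x)
        x_backward = F.leaky_relu(x, negative_slope=self.leak)
        # Autograd only sees x_backward; the detached term fixes up the value.
        return x_backward + (x_forward - x_backward).detach()

model = nn.Sequential(nn.Linear(4, 8), NonSatReLUModule(leak=0.1), nn.Linear(8, 2))
out = model(torch.randn(3, 4))
out.sum().backward()  # gradients reach the first Linear even through inactive units
```

The module could equally call a custom autograd.Function's apply in its forward method; the detach version is used here only to keep the sketch short.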

Just because this is a useful pattern, but inefficient for larger Tensors, here is a variant that is less concise in expression but avoids the unneeded overhead:

import torch

x_forward = torch.randn(5)
x_backward = torch.randn(5, requires_grad=True)

# the following is an efficient alternative for
# x2 = x_backward + (x_forward - x_backward).detach()
class FWBWMismatch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_forward, x_backward):
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        return None, grad_in.sum_to_size(ctx.shape)

x2 = FWBWMismatch.apply(x_forward, x_backward)
y = (x2**2).sum()
y.backward()

# check against expectation
print((x_backward.grad - 2*x_forward).max())

Hi, I would like to use torch.autograd to implement the leaky relu activation function, but I had trouble implementing LeakyReLU(x) = max(0, x) + negative_slope * min(0, x)... Can you give me some advice?
Thank you very much!

Thank you for your reply.
But how do I define the backward function?
My code is as follows; it fails with "'MyReLUBackward' object has no attribute 'clamp'".
import torch
from torch.autograd import Variable

class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(self, input, negative_slope):
        self.save_for_backward(input)
        output = self.clamp(input,min=0)+self.clamp(input, max=0)*negative_slope
        return output

    @staticmethod
    def backward(self, grad_output):
        input, = self.saved_tensors
        grad_input = grad_output.clone()
        return grad_input

You can follow the tutorial here. The derivative of LeakyReLU is 1 when x > 0 and NEGATIVE_SLOPE when x <= 0. As @nthn_clmnt said, the argument self shouldn't be named "self" because that is very confusing: it is actually a "context" object that holds information. When apply is called, the first argument of both forward and backward will be the context object, followed by the other arguments, so input is the actual input you passed. That is also why self.clamp fails: clamp is a method of the tensor (input.clamp), not of the context.
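
Putting those points together, one possible corrected version of the code above might look like the sketch below. The class name MyLeakyReLU and the choice to store negative_slope on the context are assumptions, not the only way to do it.

```python
import torch

class MyLeakyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, negative_slope):
        ctx.save_for_backward(input)
        ctx.negative_slope = negative_slope  # non-tensor value kept on the context
        # clamp is called on the tensor, not on ctx
        return input.clamp(min=0) + input.clamp(max=0) * negative_slope

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        # derivative is 1 for x > 0 and negative_slope for x <= 0
        grad_input[input <= 0] *= ctx.negative_slope
        # one return value per forward argument; negative_slope needs no gradient
        return grad_input, None

x = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)
y = MyLeakyReLU.apply(x, 0.01)
y.sum().backward()
print(y, x.grad)
```

The result can be compared against torch.nn.functional.leaky_relu on the same input as a sanity check.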