My understanding is that for classification tasks there is the intuition that:
(1) relu activation functions encourage sparsity, which is good (for generalization?) but that
(2) a leaky relu solves the gradient saturation problem, which relu has, at the cost of sparsity.
Is it possible, in PyTorch, to write an activation function which on the forward pass behaves like relu but which has a small positive derivative for x < 0? In other words its forward pass is relu but its backward pass is like that of a leaky relu with some specified leak rate. Of course this breaks the FTC, but that (probably?) doesn’t matter and you would have sparsity + non-saturating gradients if it worked.
I’m relatively new to deep learning, so sorry if this has already been tried. I also don’t have the experience to know if tuning activation functions leads to significant performance gains.
For ad-hoc experimentation with a backward pass that is not quite equal to the forward pass, the pattern
y = x_backward + (x_forward - x_backward).detach()
works quite well. It gets you x_forward in the forward pass, but the derivative will act as if you had x_backward. Stay clear of infinity - infinity (which is NaN). It's probably a tad more expensive than a custom autograd.Function.
That's a great idea! Just what I was looking for.
Could you please let me know how to use it? I am a little confused about x_backward. Should it be used after loss.backward()?
x_backward corresponds to the activation, computed in your forward method, whose gradient should be used during the backward pass. In this example, you could define x_backward as the LeakyReLU output and x_forward as the ReLU output.
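To make the pattern concrete, here is a minimal sketch of the detach trick applied to this case. The helper name relu_forward_leaky_backward and the leak parameter are illustrative assumptions, not something from the posts above.

```python
import torch
import torch.nn.functional as F

def relu_forward_leaky_backward(x, leak=0.1):
    # Hypothetical helper; `leak` is the slope used only in the backward pass for x < 0.
    x_forward = F.relu(x)                               # value returned in the forward pass
    x_backward = F.leaky_relu(x, negative_slope=leak)   # function whose gradient is used
    # Numerically equal to x_forward; the detached term carries no gradient,
    # so autograd differentiates x_backward only.
    return x_backward + (x_forward - x_backward).detach()

x = torch.tensor([-2.0, -0.5, 1.0, 3.0], requires_grad=True)
y = relu_forward_leaky_backward(x)
y.sum().backward()
print(y.detach())  # forward values match relu(x)
print(x.grad)      # `leak` where x < 0, 1 where x > 0
```

Nothing special happens after loss.backward(): the mismatch is built into the graph when the activation is computed.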
This should work for what you want, all wrapped up as the activation function ‘non_sat_relu’. The backward leak is set to 0.1, but it could of course be whatever you want.
self.neg = x < 0
grad_input = grad_output.clone()
grad_input[self.neg] *= 0.1
@nthn_clmnt Just curious, when using @staticmethod, why are we passing self? Plus, we can also implement this by inheriting from nn.Module, right?
Thank you for the detailed explanation and code. Much obliged
I am following this:
I guess the secret first argument ctx of forward and backward is not a class instance but something called a context. I think you forced me to learn something about Python, so thank you. So although my code runs fine, calling that argument self suggests the wrong thing.
I haven’t seen any documentation saying that you can re-implement backward at the level of an nn.Module rather than an autograd.Function.
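For reference, here is a hedged sketch of how the activation could look with the first argument named ctx, plus a thin nn.Module wrapper that simply calls the Function. The class names NonSatReLUFunction and NonSatReLU and the leak argument are assumptions made for illustration; the custom backward still lives in the autograd.Function, not in the module.

```python
import torch
import torch.nn as nn

class NonSatReLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, leak):
        # ctx is the context object, not a class instance
        ctx.save_for_backward(x)
        ctx.leak = leak
        return x.clamp(min=0)  # plain ReLU in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # scale the incoming gradient by `leak` where the input was negative
        grad_input = torch.where(x < 0, grad_output * ctx.leak, grad_output)
        return grad_input, None  # no gradient for the non-tensor `leak`

class NonSatReLU(nn.Module):
    # Thin wrapper so the activation can be used like any other layer.
    def __init__(self, leak=0.1):
        super().__init__()
        self.leak = leak

    def forward(self, x):
        return NonSatReLUFunction.apply(x, self.leak)

act = NonSatReLU(leak=0.1)
x = torch.tensor([-1.0, 0.5, 2.0], requires_grad=True)
out = act(x)
out.sum().backward()
print(out.detach(), x.grad)
```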
Because this is a useful pattern but inefficient for larger tensors, here is a variant that is less concise but avoids the unneeded overhead:
import torch

x_forward = torch.randn(5)
x_backward = torch.randn(5, requires_grad=True)

# the following is an efficient alternative for
# x2 = x_backward + (x_forward - x_backward).detach()
class FWBWMismatch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_forward, x_backward):
        ctx.shape = x_backward.shape
        return x_forward  # the forward value comes from x_forward

    @staticmethod
    def backward(ctx, grad_in):
        return None, grad_in.sum_to_size(ctx.shape)

x2 = FWBWMismatch.apply(x_forward, x_backward)
y = (x2**2).sum()
y.backward()

# check against expectation: the printed value should be 0
print((x_backward.grad - 2*x_forward).abs().max())
Hi, I would like to use torch.autograd to implement the leaky ReLU activation function, but I am having trouble implementing LeakyReLU(x) = max(0, x) + negative_slope * min(0, x). Can you give me some advice?
Thank you very much!
tc.clamp(x, min=0) + tc.clamp(x, max=0) * SLOPE
tc.where(x < 0, x * SLOPE, x)
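As a quick check, both expressions can be compared against the built-in leaky ReLU. This sketch assumes tc above is an alias for torch and picks SLOPE = 0.02 purely for illustration.

```python
import torch
import torch.nn.functional as F

SLOPE = 0.02  # illustrative value
x = torch.linspace(-3, 3, steps=5)

via_clamp = torch.clamp(x, min=0) + torch.clamp(x, max=0) * SLOPE
via_where = torch.where(x < 0, x * SLOPE, x)
reference = F.leaky_relu(x, negative_slope=SLOPE)

print(torch.allclose(via_clamp, reference), torch.allclose(via_where, reference))
```

Since both versions are built from differentiable tensor operations, autograd derives their backward pass automatically and no custom backward is needed.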
Thank you for your reply.
But how do I define it in the backward function?
My code is as follows; it raises the error “‘MyReLUBackward’ object has no attribute ‘clamp’”.
from torch.autograd import Variable
def forward(self, input, negative_slope):
output = self.clamp(input,min=0)+self.clamp(input, max=0)*negative_slope
def backward(self, grad_output):
input, = self.saved_tensors
grad_input = grad_output.clone()
input = Variable(torch.linspace(-3, 3, steps=5))
leakyrelu = MyReLU().apply
output= leakyrelu(input, 0.02)
You can follow the tutorial here. The derivative of LeakyReLU is 1 when x > 0 and negative_slope when x <= 0. As @nthn_clmnt said, the argument self shouldn’t be named “self” because that is very confusing; it is actually a “context” object that holds information. When apply is called, the first argument of both forward and backward will be the context object, followed by the other arguments, so input is the actual input you passed. That is why self.clamp fails: the context object has no clamp method, so call torch.clamp(input, ...) or input.clamp(...) instead.
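Putting this together, one possible corrected version might look like the sketch below. The class name MyLeakyReLU is an assumption, and storing negative_slope on ctx is just one way to pass it to backward.

```python
import torch

class MyLeakyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, negative_slope):
        # ctx is the context object; tensor operations are called on `input`
        ctx.save_for_backward(input)
        ctx.negative_slope = negative_slope
        return torch.clamp(input, min=0) + torch.clamp(input, max=0) * negative_slope

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # derivative is 1 for x > 0 and negative_slope for x <= 0
        slope = torch.where(
            input > 0,
            torch.ones_like(input),
            torch.full_like(input, ctx.negative_slope),
        )
        # one return value per forward argument; negative_slope gets no gradient
        return grad_output * slope, None

x = torch.linspace(-3, 3, steps=5, requires_grad=True)
y = MyLeakyReLU.apply(x, 0.02)
y.sum().backward()
print(x.grad)
```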