Smooth continuation

I would like to define the function

f(x) = exp(-1/x) for x > 0, and f(x) = 0 for x ≤ 0

This function is infinitely differentiable. Its derivatives at zero are all zero.
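
For x > 0 the first derivative is

f'(x) = exp(-1/x) / x^2

and more generally every derivative is exp(-1/x) multiplied by a polynomial in 1/x; since exp(-1/x) decays faster than any power of 1/x grows as x → 0⁺, every one-sided derivative at 0 is indeed 0.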

import torch

def f(x):
    # exp(-1/x) for x > 0, 0 otherwise
    return (-1/x).exp() * (x > 0)

x = torch.tensor(0.0, requires_grad=True)
y = f(x)
y.backward()
x.grad  # NaN

How can I make torch.autograd return the correct derivatives (all of them)?

Note: I tried with torch.autograd.Function, but I had a problem with the second derivative.


Hi,

This is expected, I'm afraid. The chain rule that autograd uses only works at points where each intermediate function is differentiable. But 1/x is not differentiable at 0, hence the NaN.
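
To see where the NaN comes from, here is a minimal sketch of what the chain rule computes at x = 0: the forward value is fine, but the intermediate gradient 1/x² is inf, and 0 * inf is NaN.

import torch

x = torch.tensor(0.0, requires_grad=True)
t = -1 / x      # -inf at x = 0
y = t.exp()     # exp(-inf) = 0, so the forward pass looks fine
y.backward()
print(x.grad)   # nan: dy/dt = exp(t) = 0, dt/dx = 1/x**2 = inf, and 0 * inf = nan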

You can indeed use a custom Function to achieve that.
Can you share the issue you ran into?
One thing to keep in mind for double backward is that you should not save intermediate Tensors and re-use them in backward; only inputs and outputs. You can make intermediate Tensors dummy outputs if needed here, as in the sketch below.
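
For illustration, a minimal, hypothetical sketch of the dummy-output trick (the toy function sin(x)**2 and all names are made up; the point is that everything passed to save_for_backward is either an input or an output of forward):

import torch

class ToyWithDummyOutput(torch.autograd.Function):
    # Toy example: out = sin(x)**2. The intermediate sin(x) is needed again in
    # backward, so instead of stashing a hidden tensor it is returned as an
    # extra (dummy) output, which makes it safe to save for double backward.
    @staticmethod
    def forward(ctx, x):
        inter = x.sin()
        out = inter * inter
        ctx.save_for_backward(x, inter)  # only inputs/outputs of forward
        return out, inter

    @staticmethod
    def backward(ctx, grad_out, grad_inter):
        x, inter = ctx.saved_tensors
        # d(out)/dx = 2*sin(x)*cos(x), d(inter)/dx = cos(x)
        return (2 * inter * grad_out + grad_inter) * x.cos()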

It worked! Thanks

import torch

class SmoothStep(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Save only the input (no intermediates), so double backward works.
        ctx.save_for_backward(x)
        y = torch.zeros_like(x)
        m = (x > 0.0)
        y[m] = (-1 / x[m]).exp()
        return y

    @staticmethod
    def backward(ctx, dy):
        x, = ctx.saved_tensors
        dx = torch.zeros_like(x)
        m = (x > 0.0)
        xm = x[m]
        # d/dx exp(-1/x) = exp(-1/x) / x**2 for x > 0, and 0 otherwise
        dx[m] = (-1 / xm).exp() / xm.pow(2)
        return dx * dy
    
f = SmoothStep.apply
x = torch.linspace(-1, 1, 1000, requires_grad=True)

y0 = f(x)
y1, = torch.autograd.grad(y0.sum(), x, create_graph=True)
y2, = torch.autograd.grad(y1.sum(), x, create_graph=True)
y3, = torch.autograd.grad(y2.sum(), x, create_graph=True)

for y in (y0, y1, y2, y3):
    assert torch.isfinite(y).all()
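
In case it helps, a sketch of how the plot below can be reproduced (assuming matplotlib is available):

import matplotlib.pyplot as plt

xs = x.detach().numpy()
for y, label in [(y0, "f"), (y1, "f'"), (y2, "f''"), (y3, "f'''")]:
    plt.plot(xs, y.detach().numpy(), label=label)
plt.legend()
plt.show()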

[Plot of f and its first three derivatives on [-1, 1]]
