How to call only the backward pass of a PyTorch function?


I am trying to implement an asymmetric threshold Function:

  1. The forward pass computes an ordinary threshold, and the backward pass computes the derivative of the sigmoid function.

So I need to call only the backward pass of torch.nn.functional.sigmoid inside my own backward pass. How can I do that? I think it would be faster than a self-implemented sigmoid derivative.

Thanks in advance!


To do so, you need to create your own Function in which you reimplement the sigmoid backward.
It should be fairly easy, since it is just grad_output * (1 - output) * output, where output is the output of the forward pass and grad_output is the gradient passed as a parameter to backward.
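As a minimal sketch of that formula, the snippet below wraps a plain sigmoid in a custom Function and compares the hand-written backward with the gradient autograd computes for torch.sigmoid. The class name ManualSigmoid and the test tensor are hypothetical and only serve as illustration.

```python
import torch


class ManualSigmoid(torch.autograd.Function):
    """Sigmoid with a hand-written backward (hypothetical name, for illustration)."""

    @staticmethod
    def forward(ctx, x):
        output = torch.sigmoid(x)
        # Keep the forward output; the derivative is expressed in terms of it.
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        output, = ctx.saved_tensors
        # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))
        return grad_output * (1 - output) * output


x = torch.randn(5, dtype=torch.double, requires_grad=True)

# Gradient from the hand-written backward.
ManualSigmoid.apply(x).sum().backward()
manual_grad = x.grad.clone()

# Gradient from the built-in sigmoid, for comparison.
x.grad = None
torch.sigmoid(x).sum().backward()
autograd_grad = x.grad.clone()

matches = torch.allclose(manual_grad, autograd_grad)
```

If the two gradients agree, matches is True, which suggests the formula can be reused inside any other custom backward.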


Yes, I did that:

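import torch

def where(cond, x_1, x_2):
    # Elementwise select: x_1 where cond is true, x_2 elsewhere.
    cond = cond.float()
    return (cond * x_1) + ((1 - cond) * x_2)

class Threshold(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Save the input; backward needs it to evaluate the sigmoid derivative.
        ctx.save_for_backward(x)
        _zeros = torch.zeros_like(x)
        _ones = torch.ones_like(x)
        return where(x > 0, _ones, _zeros)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        _slope = 100
        # Derivative of sigmoid(_slope * x), written out with explicit exponentials.
        grad_input = _slope * torch.exp(- _slope * x) / torch.pow((1 + torch.exp(- _slope * x)), 2)
        return grad_input * grad_output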

But I got NaN values in the gradients. I think the problem is related to the large slope of the sigmoid, but I have no good ideas for how to fix it.

I would do it slightly differently:

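class AsymThreshold(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, threshold, value):
        output = torch.nn.functional.threshold(x, threshold, value)
        # Save the output so that backward can use it.
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        output, = ctx.saved_tensors
        # One gradient per forward argument; threshold and value get None.
        return grad_output * (1. - output) * output, None, None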

It looks good. Thanks!
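On the NaN gradients mentioned earlier in the thread: a likely cause is that torch.exp(-_slope * x) overflows to inf in float32 once -_slope * x exceeds roughly 88, so the fraction becomes inf / inf. The sketch below is one possible way around it, not the only one: it keeps the hard threshold in the forward pass and evaluates the derivative through torch.sigmoid, which saturates to 0 or 1 instead of overflowing. It also applies the s * (1 - s) formula to a sigmoid of the saved input rather than to the thresholded output. The class name SteepSigmoidThreshold, the slope argument, and the sample inputs are assumptions made for illustration.

```python
import torch


class SteepSigmoidThreshold(torch.autograd.Function):
    """Hard threshold forward, steep-sigmoid derivative backward (hypothetical name)."""

    @staticmethod
    def forward(ctx, x, slope):
        ctx.save_for_backward(x)
        ctx.slope = slope  # plain Python float, stored directly on ctx
        return (x > 0).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # d/dx sigmoid(slope * x) = slope * s * (1 - s), with s = sigmoid(slope * x).
        s = torch.sigmoid(ctx.slope * x)
        grad_input = ctx.slope * s * (1 - s)
        # No gradient for the slope argument.
        return grad_output * grad_input, None


x = torch.tensor([-2.0, -0.01, 0.0, 0.01, 2.0], requires_grad=True)
y = SteepSigmoidThreshold.apply(x, 100.0)
y.sum().backward()

# The explicit-exponential form from the thread, for comparison.
with torch.no_grad():
    e = torch.exp(-100.0 * x)
    naive_grad = 100.0 * e / (1 + e) ** 2

has_nan_naive = bool(torch.isnan(naive_grad).any())
all_finite_stable = bool(torch.isfinite(x.grad).all())
```

With these inputs the explicit-exponential form hits inf / inf at x = -2.0 in float32, while the sigmoid-based form stays finite everywhere. The gradient is still very close to zero away from x = 0 at this slope, so a smaller slope may be worth trying if training stalls.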