How to call SciPy functions?

PyTorch doesn’t have the erfcx(x) function implemented (https://github.com/pytorch/pytorch/issues/31945). However, this function is available in SciPy (https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.erfcx.html), so I am trying to call the SciPy version from PyTorch.

I was thinking of following this tutorial: https://pytorch.org/tutorials/advanced/numpy_extensions_tutorial.html, to implement a PyTorch version of erfcx that calls the SciPy version and is automatically differentiable.

Is this the recommended approach? I would like whatever gives the best performance here.

Maybe creating the extension using SciPy directly will crack the case for you, and maybe it is the best thing to do in your context.

However, if you want to explore how to implement a PyTorch extension, I suggest you take a look at the first commit of this repo. It’s a simple extension I implemented last year. Things have changed quite a lot on the backend since then, but the idea of how it works is pretty much the same.

I guess once you get it working nicely as an extension, you can try to submit a PR with your implementation. There are a lot of subtleties when it comes to implementing an extension natively in PyTorch, so I guess it’s better to know how to make a standalone extension work first.
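If you want to experiment without setting up a full repo first, `torch.utils.cpp_extension.load_inline` compiles a C++ extension on the fly. Here is a minimal sketch (the `erfcx_naive` kernel below is just a hypothetical placeholder, not a numerically sound erfcx, and it needs a working C++ toolchain to compile):

import torch
from torch.utils.cpp_extension import load_inline

# C++ source compiled at load time.
cpp_source = """
#include <torch/extension.h>

// Hypothetical placeholder kernel: exp(x*x) * erfc(x), elementwise.
// A real erfcx would need a numerically stable algorithm instead.
torch::Tensor erfcx_naive(torch::Tensor x) {
    return torch::exp(x * x) * torch::erfc(x);
}
"""

ext = load_inline(name="erfcx_ext", cpp_sources=cpp_source,
                  functions=["erfcx_naive"])
print(ext.erfcx_naive(torch.randn(3)))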

Here is what I did:

import numpy as np
import scipy.special
import torch

class ErfcxFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Evaluate with SciPy on CPU, then wrap the result back into a tensor.
        result = scipy.special.erfcx(input.detach().cpu().numpy())
        result = torch.from_numpy(result).to(dtype=input.dtype, device=input.device)
        ctx.save_for_backward(input, result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        input, result = ctx.saved_tensors
        # d/dx erfcx(x) = 2*x*erfcx(x) - 2/sqrt(pi)
        df = -2 / np.sqrt(np.pi) + 2 * input * result
        return df * grad_output

erfcx = ErfcxFunction.apply

def logerfc(x):
    # log(erfc(x)): use erfc(x) = exp(-x^2) * erfcx(x) to stay finite for large positive x.
    return torch.where(x > 0, erfcx(x).log() - x**2, x.erfc().log())

def logerfcx(x):
    # log(erfcx(x)): use erfcx(x) = exp(x^2) * erfc(x) for negative x.
    return torch.where(x < 0, x.erfc().log() + x**2, erfcx(x).log())
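As a quick sanity check (my own sketch; the tolerance is illustrative), the stable and naive formulations should agree for moderate x, where neither branch underflows or overflows:

x = torch.linspace(-2.0, 2.0, 9)
print(torch.allclose(logerfc(x), x.erfc().log(), atol=1e-4))                      # expect True
print(torch.allclose(logerfcx(x), (x.pow(2).exp() * x.erfc()).log(), atol=1e-4))  # expect True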

However, this is very slow:

A = torch.randn(50,50)
%timeit logerfcx(A) # 428 µs ± 22.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit torch.erfc(A) # 13.9 µs ± 561 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

Anything I can do to make this faster?

The slowness is probably due to the SciPy implementation of erfcx itself. In a Colab environment it’s about 13x slower than torch.erfc.
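You can see this by timing the raw calls, without the autograd wrapper (a sketch; the numbers will vary by machine):

A = torch.randn(50, 50)
A0 = A.numpy()
%timeit scipy.special.erfcx(A0)  # the bare SciPy ufunc, no autograd wrapper
%timeit torch.erfc(A)            # the native PyTorch kernel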

You could do a naive reimplementation of erfcx, although it will have numerical precision issues:

import scipy.special
import torch

torch.manual_seed(123456)

# Naive erfcx: fine near zero, but the absolute error grows where erfcx is large.
def erfcx(A): return torch.exp(A * A) * torch.erfc(A)

A = torch.randn(50, 50)
A0 = A.numpy()
# Max deviation from the SciPy reference on standard-normal inputs:
torch.max(abs(erfcx(A) - torch.tensor(scipy.special.erfcx(A0))))  # tensor(0.0625)
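The bigger issue is the tails: exp(A*A) overflows float32 long before erfcx itself becomes large, so the naive version breaks down for large positive inputs (illustrative check):

x = torch.tensor([5.0, 10.0])
print(erfcx(x))                                      # exp(100) overflows float32, so the second entry is inf
print(torch.tensor(scipy.special.erfcx(x.numpy())))  # stays finite, roughly 1/(x*sqrt(pi))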

I need erfcx because otherwise logerfc won’t work in the tails of the Gaussian.
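For example (a sketch; erfc(20) underflows to zero in float32, so its log is -inf):

x = torch.tensor([20.0])
print(x.erfc().log())  # -inf: erfc(20) underflows to zero
print(logerfc(x))      # finite: log(erfcx(20)) - 400, roughly -403.6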