PyTorch doesn't have the erfcx(x) function implemented (https://github.com/pytorch/pytorch/issues/31945). However, this function is available in SciPy (https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.erfcx.html), so I am trying to call the SciPy version from PyTorch.
I was thinking of following this tutorial, https://pytorch.org/tutorials/advanced/numpy_extensions_tutorial.html, to implement a PyTorch version of erfcx that calls the SciPy version and is automatically differentiable.
Is this the recommended approach? I would like whatever approach gives the best performance here.
Creating the extension using SciPy directly may well do the job for you, and may be the best thing to do in your context.
However, if you want to explore how to implement a PyTorch extension, I suggest you take a look at the first commit of this repo. It's a simple extension I implemented last year. The backend has changed quite a lot since then, but the idea of how it works is pretty much the same.
Once you get it working really well as an extension, you could try to submit a PR with your implementation. There are a lot of subtleties when it comes to implementing it natively in PyTorch, so it's better to get a standalone extension working first.
Here is what I did (the custom Function wraps SciPy, and the two log helpers pick the numerically safe branch on each tail):

import numpy as np, scipy.special, torch

class ErfcxFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        result = torch.from_numpy(scipy.special.erfcx(input.detach().numpy()))
        ctx.save_for_backward(input, result)
        return result
    @staticmethod
    def backward(ctx, grad_output):
        input, result = ctx.saved_tensors
        df = -2 / np.sqrt(np.pi) + 2 * input * result  # d/dx erfcx(x)
        return df * grad_output

erfcx = ErfcxFunction.apply

def logerfc(x):   # log(erfc(x)), stable on the right tail where erfc underflows
    return torch.where(x > 0, erfcx(x).log() - x**2, x.erfc().log())

def logerfcx(x):  # log(erfcx(x)), stable on the left tail where erfcx overflows
    return torch.where(x < 0, x.erfc().log() + x**2, erfcx(x).log())
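To sanity-check the analytic backward against finite differences, `torch.autograd.gradcheck` can be used in double precision. A minimal standalone sketch (the Function is restated here so the snippet runs on its own):

```python
import numpy as np
import scipy.special
import torch

class ErfcxFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # scipy works on NumPy arrays, so detach and round-trip
        result = torch.from_numpy(scipy.special.erfcx(input.detach().numpy()))
        ctx.save_for_backward(input, result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        input, result = ctx.saved_tensors
        # d/dx erfcx(x) = 2*x*erfcx(x) - 2/sqrt(pi)
        return (2 * input * result - 2 / np.sqrt(np.pi)) * grad_output

# gradcheck compares the analytic gradient to finite differences;
# it needs float64 inputs to be reliable
x = torch.randn(20, dtype=torch.float64, requires_grad=True)
print(torch.autograd.gradcheck(ErfcxFunction.apply, (x,)))  # True if backward is correct
```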
However this is very slow:
A = torch.randn(50,50)
%timeit logerfcx(A) # 428 µs ± 22.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit torch.erfc(A) # 13.9 µs ± 561 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Anything I can do to make this faster?
The slowness is probably due to SciPy's implementation of erfcx. In a Colab environment it's about 13x slower than torch.erfc.
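Part of the gap may also come from the tensor-to-array round trip in the wrapper rather than from SciPy itself. A rough way to separate the two costs (a sketch; absolute numbers will vary by machine):

```python
import timeit
import scipy.special
import torch

A = torch.randn(50, 50)
A_np = A.numpy()  # shares memory with A, no copy

# time scipy's erfcx on a bare NumPy array vs. torch.erfc on the tensor,
# so no conversion overhead is included in either measurement
t_scipy = timeit.timeit(lambda: scipy.special.erfcx(A_np), number=1000)
t_torch = timeit.timeit(lambda: torch.erfc(A), number=1000)
print(f"scipy erfcx: {t_scipy * 1e3:.1f} us/loop")
print(f"torch erfc:  {t_torch * 1e3:.1f} us/loop")
```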
You could do a naive reimplementation of erfcx, although it will have numerical precision issues:
def erfcx(A): return torch.exp(A*A)*torch.erfc(A)
A = torch.randn(50,50)
torch.max(abs(erfcx(A) - torch.from_numpy(scipy.special.erfcx(A.numpy())))) # tensor(0.0625)
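For what it's worth, that 0.0625 figure is an absolute error. Since erfcx grows roughly like 2*exp(x**2) for negative x, float32 rounding in A*A gets amplified by the exp, while the relative error stays small. Comparing against a float64 SciPy reference makes this visible (a sketch; exact numbers depend on the random draw):

```python
import numpy as np
import scipy.special
import torch

def erfcx_naive(A):
    # naive float32 reimplementation from the thread
    return torch.exp(A * A) * torch.erfc(A)

A = torch.randn(50, 50)
# float64 reference from scipy
ref = torch.from_numpy(scipy.special.erfcx(A.numpy().astype(np.float64)))
abs_err = (erfcx_naive(A).double() - ref).abs()
rel_err = abs_err / ref.abs()
print(abs_err.max())  # can be sizeable where erfcx is large (negative A)
print(rel_err.max())  # stays near float32 precision scaled by A**2
```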
Unfortunately, I need an accurate erfcx, because otherwise logerfc won't work on the tails of the Gaussian.
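The tail problem is easy to demonstrate: in float32, erfc underflows to zero a little past x = 10, so its log becomes -inf, while routing through erfcx stays finite. A minimal standalone sketch of the logerfc idea, calling SciPy's erfcx directly (the helper name here is illustrative):

```python
import scipy.special
import torch

def logerfc(x):
    # log(erfc(x)) = log(erfcx(x)) - x**2, which sidesteps the underflow of erfc
    erfcx = torch.from_numpy(scipy.special.erfcx(x.numpy()))
    return torch.where(x > 0, erfcx.log() - x**2, x.erfc().log())

x = torch.tensor([12.0])
print(torch.erfc(x).log())  # tensor([-inf]): erfc(12) underflows in float32
print(logerfc(x))           # finite, roughly -147
```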