Here is what I did:
```python
import numpy as np
import scipy.special
import torch

class ErfcxFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # compute erfcx with scipy (runs on CPU via numpy)
        result = scipy.special.erfcx(input)
        ctx.save_for_backward(input, result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        input, result = ctx.saved_tensors
        # d/dx erfcx(x) = 2*x*erfcx(x) - 2/sqrt(pi)
        df = -2 / np.sqrt(np.pi) + 2 * input * result
        return df * grad_output

erfcx = ErfcxFunction.apply
```
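As a sanity check (not in the original post), the derivative identity used in `backward`, d/dx erfcx(x) = 2*x*erfcx(x) - 2/sqrt(pi), can be verified against central finite differences using SciPy alone:

```python
import numpy as np
from scipy.special import erfcx

# Compare the closed-form derivative of erfcx against a numerical estimate.
x = np.linspace(-3.0, 3.0, 13)
analytic = 2 * x * erfcx(x) - 2 / np.sqrt(np.pi)
h = 1e-6
numeric = (erfcx(x + h) - erfcx(x - h)) / (2 * h)  # central difference
print(np.max(np.abs(analytic - numeric)))
```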
```python
def logerfc(x):
    # for x > 0 use log(erfcx(x)) - x**2, since erfc(x) = erfcx(x) * exp(-x**2)
    return torch.where(x > 0, erfcx(x).log() - x**2, x.erfc().log())

def logerfcx(x):
    # for x < 0 use log(erfc(x)) + x**2, since erfcx(x) = erfc(x) * exp(x**2)
    return torch.where(x < 0, x.erfc().log() + x**2, erfcx(x).log())
```
However, this is much slower than the built-in `torch.erfc`:
```python
A = torch.randn(50, 50)
%timeit logerfcx(A)    # 428 µs ± 22.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit torch.erfc(A)  # 13.9 µs ± 561 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```
Anything I can do to make this faster?