Hey, I'm looking for an n-choose-k function similar to scipy.special.comb. Any ideas where I can find it? Thanks!
There is torch.special in PyTorch 1.9, but it is still far from complete.
I would suggest using the scipy function or implementing it directly. Note that PyTorch integer tensors are 64-bit by default, and values can get large quickly with these factorials.
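A minimal sketch of implementing it directly, assuming exact integer results are wanted. comb_int is a hypothetical helper (not a PyTorch function). It uses the multiplicative recurrence C(n, i) = C(n, i - 1) * (n - i + 1) / i on int64 tensors, so no full factorial is ever formed and every division is exact. The intermediate product i * C(n, i) can still overflow int64 silently for large n (C(67, 33) alone already exceeds the int64 maximum of about 9.2e18), so treat the result as exact only for moderate n.

```python
import torch


def comb_int(n: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Exact n choose k for integer tensors (hypothetical helper, not a PyTorch API).

    Uses C(n, i) = C(n, i - 1) * (n - i + 1) / i, where each division is exact.
    Entries with k < 0 or k > n return 0, matching the usual convention.
    """
    n, k = torch.broadcast_tensors(n.to(torch.int64), k.to(torch.int64))
    valid = (k >= 0) & (k <= n)
    # Symmetry C(n, k) == C(n, n - k) keeps the loop short.
    k_eff = torch.where(valid, torch.minimum(k, n - k), torch.zeros_like(k))
    result = torch.ones_like(n)
    max_k = int(k_eff.max()) if k_eff.numel() > 0 else 0
    for i in range(1, max_k + 1):
        step = torch.div(result * (n - i + 1), i, rounding_mode="floor")
        # Only advance entries that still need more factors.
        result = torch.where(i <= k_eff, step, result)
    return torch.where(valid, result, torch.zeros_like(result))


print(comb_int(torch.tensor([5, 10, 52, 3]), torch.tensor([3, 0, 5, 4])))
```

The loop runs max(min(k, n - k)) times over the whole tensor, which is fine for small k but slow if some entries need thousands of steps; for those, the floating-point lgamma() approach is more practical.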
Assuming that the scipy function computes the conventional binomial
coefficient, you may use:
((n + 1).lgamma() - (k + 1).lgamma() - ((n - k) + 1).lgamma()).exp()
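In math notation, the expression above is the gamma-function form of the binomial coefficient, computed in log space and then exponentiated (each lgamma() call is one ln Γ term):

```latex
\binom{n}{k} = \frac{n!}{k!\,(n-k)!}
             = \frac{\Gamma(n+1)}{\Gamma(k+1)\,\Gamma(n-k+1)}
             = \exp\bigl(\ln\Gamma(n+1) - \ln\Gamma(k+1) - \ln\Gamma(n-k+1)\bigr)
```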
Binomial coefficients may be computed in terms of factorials, the gamma function is the factorial function with its argument shifted by one (Γ(n + 1) = n!), and pytorch implements an autograd-aware log-gamma function, lgamma().
So you can use lgamma() to build a binomial-coefficient function that is differentiable and autograd-aware:
>>> import torch
>>> torch.__version__
'1.7.1'
>>> def combo (n, k):
...     return ((n + 1).lgamma() - (k + 1).lgamma() - ((n - k) + 1).lgamma()).exp()
...
>>> tfive = torch.tensor ([5.0], requires_grad = True)
>>> tthree = torch.tensor ([3.0], requires_grad = True)
>>> c = combo (tfive, tthree)
>>> c
tensor([10.], grad_fn=<ExpBackward>)
>>> c.backward()
>>> tfive.grad
tensor([7.8333])
>>> tthree.grad
tensor([-3.3333])
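The gradients printed in that session can be checked by hand. Differentiating the log of the formula brings in the digamma function ψ = (ln Γ)', and the recurrence ψ(x + 1) = ψ(x) + 1/x turns the values at n = 5, k = 3 into simple fractions:

```latex
\frac{\partial C}{\partial n} = C\,\bigl[\psi(n+1) - \psi(n-k+1)\bigr],
\qquad
\frac{\partial C}{\partial k} = C\,\bigl[\psi(n-k+1) - \psi(k+1)\bigr]

\left.\frac{\partial C}{\partial n}\right|_{n=5,\,k=3}
  = 10\,\bigl[\psi(6) - \psi(3)\bigr]
  = 10\left(\tfrac{1}{3} + \tfrac{1}{4} + \tfrac{1}{5}\right)
  = \tfrac{47}{6} \approx 7.8333

\left.\frac{\partial C}{\partial k}\right|_{n=5,\,k=3}
  = 10\,\bigl[\psi(3) - \psi(4)\bigr]
  = -\tfrac{10}{3} \approx -3.3333
```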
Note that the binomial coefficient grows very quickly (as does the factorial function), so in many situations it is better to work with logarithms. If that fits your use case, just drop the trailing .exp() in the definition of combo().
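A minimal sketch of that log-space variant, assuming a typical use case such as a binomial log-probability. lcombo is simply combo without the trailing .exp(), and log_binom_pmf is a hypothetical helper; the comparison uses torch.distributions.Binomial, which is part of PyTorch. It also shows why logs help: C(1000, 500) overflows to inf in float32, while its logarithm (roughly 689) is harmless.

```python
import torch


def lcombo(n: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # log of n choose k: combo() without the trailing .exp()
    return (n + 1).lgamma() - (k + 1).lgamma() - ((n - k) + 1).lgamma()


def log_binom_pmf(k: torch.Tensor, n: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: log P(K = k) for K ~ Binomial(n, p), kept entirely in log space.
    return lcombo(n, k) + k * torch.log(p) + (n - k) * torch.log1p(-p)


big_n = torch.tensor(1000.0)
big_k = torch.tensor(500.0)
log_c = lcombo(big_n, big_k)
print(log_c)        # finite, roughly 689
print(log_c.exp())  # the float32 coefficient itself overflows to inf

# Cross-check against PyTorch's own Binomial distribution (float64 for a tight comparison).
n = torch.tensor(20.0, dtype=torch.float64)
p = torch.tensor(0.3, dtype=torch.float64)
ks = torch.arange(21, dtype=torch.float64)
mine = log_binom_pmf(ks, n, p)
ref = torch.distributions.Binomial(total_count=n, probs=p).log_prob(ks)
print(torch.allclose(mine, ref))
```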
(The beta function is essentially the binomial-coefficient function in disguise, since C(n, k) = 1 / ((n + 1) B(n - k + 1, k + 1)), so it would be nice if pytorch implemented an lbeta() function, but it's hardly necessary.)
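For illustration, such a function is a one-liner on top of lgamma(). In the sketch below, lbeta and lcombo_via_beta are hypothetical helpers defined here, not calls into PyTorch; the binomial coefficient is recovered through the identity C(n, k) = 1 / ((n + 1) B(n - k + 1, k + 1)).

```python
import torch


def lbeta(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: log of B(a, b) = Gamma(a) Gamma(b) / Gamma(a + b).
    return a.lgamma() + b.lgamma() - (a + b).lgamma()


def lcombo_via_beta(n: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # log C(n, k) = -log(n + 1) - log B(n - k + 1, k + 1)
    return -torch.log(n + 1) - lbeta(n - k + 1, k + 1)


n = torch.tensor([5.0], requires_grad=True)
k = torch.tensor([3.0], requires_grad=True)
c = lcombo_via_beta(n, k).exp()
c.backward()
print(c, n.grad, k.grad)  # value and gradients should agree with combo() in the session above
```

Since lgamma(n + 2) - log(n + 1) equals lgamma(n + 1), this is the same function as combo(), just written through the beta function, so autograd gives the same gradients.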