Hi Bin!
It depends on what you are actually asking.
The function torch.poisson() is not differentiable (even though it has
a grad_fn). There are two reasons for this. First, torch.poisson()
returns a sample from the Poisson distribution, and pytorch does not
support backpropagating through Poisson samples (torch.distributions.Poisson
has no reparameterized rsample()). Second, Poisson samples are integers
(even though torch.poisson() returns them as floating-point values), and
you can't really differentiate a result that you can't continuously vary
away from its integer value.
If you try to backpropagate through torch.poisson(), you will get zeros
for the gradients (even though its result appears to have a grad_fn):
>>> import torch
>>> torch.__version__
'1.11.0'
>>> _ = torch.manual_seed (2022)
>>> rate = torch.rand (100, 100, requires_grad = True)
>>> samp = torch.poisson (rate)
>>> samp.grad_fn
<PoissonBackward0 object at 0x00000242A15A7B50>
>>> samp.sum().backward()
>>> torch.all (rate.grad == 0.0) # no useful gradients
tensor(True)
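As a supplementary sketch (not part of the original post), both points can also be seen through torch.distributions.Poisson. Its sample() draws under torch.no_grad(), so the result is detached from rate, and it does not provide a reparameterized rsample(). The sketch also checks that the samples, although stored as floats, are whole numbers. The rate of 3.5 and the sample count of 1000 are arbitrary choices.

```python
import torch

_ = torch.manual_seed(2022)
# An arbitrary rate (3.5) for 1000 independent Poisson variables.
rate = torch.full((1000,), 3.5, requires_grad=True)
dist = torch.distributions.Poisson(rate)

# Poisson does not advertise a reparameterized (differentiable) sampler ...
print(dist.has_rsample)

# ... and calling rsample() anyway raises NotImplementedError.
try:
    dist.rsample()
    has_reparam_path = True
except NotImplementedError:
    has_reparam_path = False
print(has_reparam_path)

# sample() draws under torch.no_grad(), so the result is detached from rate.
samp = dist.sample()
print(samp.requires_grad, samp.grad_fn)

# The samples are stored as float32, but every value is a whole number.
print(samp.dtype, torch.equal(samp, samp.round()))
```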
On the other hand, it does make sense to differentiate the probability (or,
for a continuous distribution, the probability density) of some sample value
with respect to parameters of the distribution.
Consider:
>>> rate = torch.tensor ([3.5], requires_grad = True)
>>> samp = torch.tensor ([4.0], requires_grad = True)
>>> lp = torch.distributions.Poisson (rate).log_prob (samp)
>>> lp
tensor([-1.6670], grad_fn=<SubBackward0>)
>>> lp.backward()
>>> rate.grad
tensor([0.1429])
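Both printed numbers can be checked by hand. Writing k for samp and λ for rate, the Poisson log-probability and its derivative with respect to the rate are:

```latex
\begin{aligned}
\log p(k;\lambda) &= k\log\lambda - \lambda - \log k!
  = 4\log 3.5 - 3.5 - \log 24 \approx -1.6670 \\
\frac{\partial \log p(k;\lambda)}{\partial \lambda} &= \frac{k}{\lambda} - 1
  = \frac{4}{3.5} - 1 = \frac{1}{7} \approx 0.1429
\end{aligned}
```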
Oddly enough, pytorch does return a non-trivial value for the grad of
samp, even though .log_prob (samp) refuses to accept a non-integer
value of samp:
>>> samp.grad
tensor([-0.2534])
>>> torch.distributions.Poisson (rate).log_prob (torch.tensor ([4.01]))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<path_to_pytorch_install>\torch\distributions\poisson.py", line 61, in log_prob
    self._validate_sample(value)
  File "<path_to_pytorch_install>\torch\distributions\distribution.py", line 286, in _validate_sample
    raise ValueError(
ValueError: Expected value argument (Tensor of shape (1,)) to be within the support (IntegerGreaterThan(lower_bound=0)) of the distribution Poisson(rate: tensor([3.5000], requires_grad=True)), but found invalid values:
tensor([4.0100])
(It beats me what samp.grad = tensor([-0.2534]) means or where it
comes from.)
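One possible explanation, offered as an assumption rather than something established by the pytorch documentation: log_prob appears to evaluate the closed-form expression k log(λ) - λ - lgamma(k + 1), which is defined for non-integer k as well, and autograd simply differentiates that expression with respect to k. Its derivative is log(λ) - digamma(k + 1), and ln(3.5) - ψ(5) ≈ 1.2528 - 1.5061 ≈ -0.2534, matching the printed samp.grad. The sketch below compares autograd's answer with that formula and with a central finite difference. It uses the validate_args = False constructor argument to switch off the support check so that log_prob accepts non-integer k; the step size h is an arbitrary choice.

```python
import torch

rate = torch.tensor([3.5], requires_grad=True)
samp = torch.tensor([4.0], requires_grad=True)

# The gradient autograd reports for the sample value itself.
lp = torch.distributions.Poisson(rate).log_prob(samp)
lp.backward()
print(samp.grad)

# Assumed explanation: log_prob evaluates k*log(rate) - rate - lgamma(k + 1),
# whose derivative with respect to k is log(rate) - digamma(k + 1).
with torch.no_grad():
    analytic = torch.log(rate) - torch.digamma(samp + 1)
print(analytic)

# Cross-check with a central finite difference in k (in float64 for accuracy).
# validate_args=False disables the support check, so non-integer k is accepted.
rate64 = torch.tensor([3.5], dtype=torch.float64)
dist = torch.distributions.Poisson(rate64, validate_args=False)
h = 1e-4
k = torch.tensor([4.0], dtype=torch.float64)
fd = (dist.log_prob(k + h) - dist.log_prob(k - h)) / (2 * h)
print(fd)
```

If this reading is right, samp.grad is the slope of a smooth interpolation of the log-probability between integer values of k, not a property of the Poisson distribution itself.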
Best.
K. Frank