Is sample operation differentiable?

When I sample from a torch distribution and then try to get parameter.grad, I get the error “element 0 of tensors does not require grad and does not have a grad_fn”, which makes me think sampling is not differentiable. The code is below.

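import torch

loc = torch.tensor([0.0, 2.0, 2.0, 3.0]).requires_grad_(True)
scale = torch.tensor([1.0, 1.0, 1.0, 1.0]).requires_grad_(True)
a = torch.distributions.normal.Normal(loc, scale).sample()
a.sum().backward()  # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn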

However, when sampling with pyro.sample, I do get gradients for the parameters.

import torch
import pyro

loc = torch.tensor([0.0, 2.0, 2.0, 3.0]).requires_grad_(True)
scale = torch.tensor([1.0, 1.0, 1.0, 1.0]).requires_grad_(True)
a = pyro.sample("my_sample", pyro.distributions.Normal(loc, scale))
a.sum().backward()  # no error; loc.grad and scale.grad are populated

As far as I know, pyro.distributions is a wrapper around torch.distributions. Does pyro do something extra to make sampling differentiable?

Are you trying to backward through the distribution somehow (which I don’t think is possible), or simply get the gradient of a (which I think is possible)?

>>> import torch
>>> loc = torch.tensor([0.0,2.0, 2.0, 3.0]).requires_grad_(True)
>>> scale = torch.tensor([1.0, 1.0, 1.0, 1.0]).requires_grad_(True)
>>> a = torch.distributions.normal.Normal(loc, scale).sample()
>>> a.sum().backward()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/eqy/.local/lib/python3.8/site-packages/torch/tensor.py", line 245, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/eqy/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 145, in backward
    Variable._execution_engine.run_backward(
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
>>> a.requires_grad = True
>>> a.sum().backward()
>>> a.grad
tensor([1., 1., 1., 1.])
>>>
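To spell out what the session above does: setting a.requires_grad = True turns the sample into a brand-new leaf tensor, so the ones in a.grad are just the gradient of a.sum() with respect to a itself. Nothing flows back to loc or scale. A minimal sketch of that check (same values as the question):

```python
import torch

loc = torch.tensor([0.0, 2.0, 2.0, 3.0], requires_grad=True)
scale = torch.tensor([1.0, 1.0, 1.0, 1.0], requires_grad=True)

# .sample() runs under torch.no_grad(), so the result is detached from loc and scale
a = torch.distributions.Normal(loc, scale).sample()
print(a.requires_grad, a.grad_fn)  # False None

# Flipping requires_grad makes `a` a new leaf tensor ...
a.requires_grad_(True)
a.sum().backward()
print(a.grad)  # tensor([1., 1., 1., 1.])

# ... but no gradient reaches the distribution's parameters
print(loc.grad, scale.grad)  # None None
```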

Hi Zhping!

As eqy notes, you cannot use autograd to get the gradient of a sample from
a pytorch distribution with respect to that distribution’s parameters.

It is not unreasonable, however, to sample from a standard normal distribution
and then use your parameters to transform that normal sample into your desired
gaussian sample. With this approach you can use autograd to differentiate the
gaussian sample with respect to the parameters.
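In symbols (writing μ for loc, σ for scale and ε for the standard normal sample; this notation is added here for illustration), the transform and its derivatives are:

$$
a = \mu + \sigma\,\varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, 1)
$$

$$
\frac{\partial a}{\partial \mu} = 1, \qquad \frac{\partial a}{\partial \sigma} = \varepsilon
$$

so loc.grad should come out as all ones and scale.grad should equal the standard normal sample itself, which is what the session that follows shows.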

Consider:

>>> import torch
>>> print (torch.__version__)
1.13.1
>>>
>>> _ = torch.manual_seed (2023)
>>>
>>> loc = torch.tensor ([0.0, 2.0, 2.0, 3.0]).requires_grad_ (True)
>>> scale = torch.tensor ([1.0, 1.0, 1.0, 1.0]).requires_grad_ (True)
>>>
>>> # you could also use  a = torch.randn (4)
>>> a = torch.distributions.normal.Normal (torch.zeros (4), torch.ones (4)).sample()
>>> print (a)
tensor([-1.2075,  0.5493, -0.3856,  0.6910])
>>>
>>> # transform standard normal samples into gaussians with your loc and scale
>>> a = scale * a + loc
>>> print (a)
tensor([-1.2075,  2.5493,  1.6144,  3.6910], grad_fn=<AddBackward0>)
>>>
>>> a.sum().backward()
>>>
>>> print (loc.grad)
tensor([1., 1., 1., 1.])
>>> print (scale.grad)
tensor([-1.2075,  0.5493, -0.3856,  0.6910])

I have no idea how pyro handles this. If you cared, you could try to reverse
engineer it by looking at the gradients you get from the pyro samples and see
if they look anything like the gradients we get in the above example.

Best.

K. Frank


Thanks, eqy! I’m not trying to get the gradient of the sample result. I’m trying to figure out whether sampling is differentiable, because pyro gives me gradients for parameters like loc and scale. I also don’t know whether the gradient of a sample even makes sense, since sampling is stochastic.

Thank you, K. Frank! Why can’t we sample from Normal(loc, scale) directly, and instead sample a from Normal(0, 1) and compute a * scale + loc? Does the resulting gradient make sense?
Should I understand that differentiating a sample from a distribution is meaningless because sampling is stochastic?
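On whether the gradient "makes sense": the gradient from a single draw is itself random, but with a = loc + scale * eps, averaging gradients over many draws estimates the gradient of the expected value, d/dθ E[f(a)] = E_eps[d/dθ f(loc + scale * eps)]. A small sketch, assuming f(a) = a ** 2 so the exact answer is known (E[a^2] = loc^2 + scale^2, with gradients 2 * loc and 2 * scale); loc = 1, scale = 2 and the sample count are arbitrary illustration values:

```python
import torch

torch.manual_seed(0)

loc = torch.tensor(1.0, requires_grad=True)
scale = torch.tensor(2.0, requires_grad=True)

n = 200_000
# rsample() keeps the graph: each draw is loc + scale * eps with eps ~ N(0, 1)
a = torch.distributions.Normal(loc, scale).rsample((n,))

# Monte Carlo estimate of E[a ** 2]; its exact value is loc**2 + scale**2 = 5
estimate = (a ** 2).mean()
estimate.backward()

# Exact gradients of loc**2 + scale**2 are 2 * loc = 2 and 2 * scale = 4
print(estimate.item())    # close to 5
print(loc.grad.item())    # close to 2
print(scale.grad.item())  # close to 4
```

Each individual gradient is noisy, but its average converges to the gradient of the expectation, which is the quantity usually optimized.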

After reading your reply, I tried this:

import torch
import pyro

# loc and scale deliberately do NOT require grad this time
loc = torch.zeros(4)
scale = torch.ones(4)
a = pyro.sample("my_sample", pyro.distributions.Normal(loc, scale))
a.sum().backward()
print(f"after backward, loc.grad = {loc.grad} and scale.grad = {scale.grad}")

and it reports the error “element 0 of tensors does not require grad and does not have a grad_fn”.
I think I now understand what pyro does, thank you!