Backward through Bernoulli sample

I am trying to backpropagate through a Bernoulli sample. A simplified version of my code is:

...
parameter = torch.nn.Parameter(torch.ones(8))  # parameters to learn (requires_grad=True by default)
prob = torch.sigmoid(parameter)  # sigmoid to map into a probability in [0, 1]
bern_sample = torch.distributions.bernoulli.Bernoulli(prob).sample()  # draw a sample

loss = torch.sum(bern_sample)
loss.backward()

Intuitively, the variable parameter should be updated by gradient descent so that all probabilities move close to 0, i.e. bern_sample should contain all zeros.

However, when I check the gradients with respect to parameter, i.e. parameter.grad, they are None.

What am I doing wrong?

Thanks in advance

Hello Ruggero!

Sampling from a discrete distribution isn’t differentiable. So pytorch’s
autograd won’t (and can’t) backpropagate any gradients from
subsequent functions back through the Bernoulli-sampling step.

The value returned from sampling a Bernoulli distribution (or any
discrete distribution) is a discrete value – 0.0 or 1.0. So you can’t
differentiate it. That is, you can’t make your sample value a little
smaller, say 0.99 instead of 1.0, by making your parameter prob
a little smaller.
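This is easy to verify directly. In the following minimal check (an illustration of the point above, not a fix), the sampled tensor comes back detached from the autograd graph, so nothing can flow back to parameter:

```python
import torch

parameter = torch.nn.Parameter(torch.ones(8))
prob = torch.sigmoid(parameter)  # prob tracks gradients
sample = torch.distributions.Bernoulli(prob).sample()  # the sample does not

print(prob.requires_grad)    # True  -- sigmoid is differentiable
print(sample.requires_grad)  # False -- sampling cut the graph
print(sample.grad_fn)        # None  -- no backward function recorded
```

Since sample has no grad_fn, there is no path for backward() to reach parameter.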

Best.

K. Frank


Hello K Frank,
I’m just wondering whether sampling from other distributions is a differentiable operation or not?

Hi Anurag!

I do not believe that pytorch supports differentiating a sample from a
(continuous) distribution with respect to parameters describing that
distribution. (But I haven’t checked.)

However, doing so can make mathematical sense, and you can
sometimes structure your calculations so that it works.

Consider a gaussian distribution with mean mu, and standard
deviation sigma. You can sample from such a distribution by
first sampling a value z from a normal distribution (mean = 0,
standard deviation = 1), and then transforming that value:

x = sigma * z + mu

If mu and sigma are packaged as Tensors that have
requires_grad = True, then calling x.backward() will
compute gradients for sigma and mu.

(Whether you will get the result you expect depends on what you
were expecting, but this scheme is conceptually self-consistent.)
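A minimal sketch of the reparameterization described above: sample a standard-normal z (which involves no parameters), then shift and scale it with learnable mu and sigma, so the transformation itself is differentiable:

```python
import torch

mu = torch.tensor(2.0, requires_grad=True)
sigma = torch.tensor(0.5, requires_grad=True)

z = torch.randn(())   # fixed noise; the sampling step sees no parameters
x = sigma * z + mu    # differentiable w.r.t. mu and sigma

x.backward()
print(mu.grad)     # dx/dmu = 1
print(sigma.grad)  # dx/dsigma = z
```

(For what it’s worth, torch.distributions does implement exactly this pathwise trick for many continuous distributions: Normal(mu, sigma).rsample() returns a reparameterized sample that supports backpropagation, whereas .sample() does not.)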

Best.

K. Frank

I solved the problem by using the reparameterization trick with Gumbel-Softmax. I took inspiration from this: https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f
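For anyone landing here later, a rough sketch of that workaround applied to the original problem (the variable names here are my own, not from the linked gist): relax the discrete Bernoulli sample with torch.nn.functional.gumbel_softmax, which with hard=True returns one-hot samples in the forward pass but uses the soft relaxation’s gradient in the backward pass (straight-through):

```python
import torch
import torch.nn.functional as F

parameter = torch.nn.Parameter(torch.ones(8))

# Two logits per element: [log p(0), log p(1)] via log-sigmoid.
logits = torch.stack([F.logsigmoid(-parameter), F.logsigmoid(parameter)], dim=-1)

# Forward: one-hot (discrete-looking) samples; backward: soft gradients.
sample = F.gumbel_softmax(logits, tau=1.0, hard=True)

loss = sample[..., 1].sum()  # count of "1" outcomes
loss.backward()

print(parameter.grad is not None)  # True -- gradients now reach parameter
```

Minimizing this loss pushes parameter toward large negative values, i.e. probabilities close to 0, which is the behavior the original question was after.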

Thanks for your help anyway