Backward through Bernoulli sample

I need some information and help on how to do Bernoulli distribution sampling using weights. Could you please help me with this if you have some free time?

What kind of issues are you facing?

Check whether RelaxedBernoulli (which supports rsample) does what you want. For more advanced use, such as modeling uncertainty, sample the Bernoulli parameter from a trainable Beta distribution:

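Pseudo-code is something like:

from torch.nn.functional import softplus
from torch.distributions import Beta, RelaxedBernoulli

P = Beta(softplus(raw_a), softplus(raw_b))  # raw_a = nn.Parameter(…)
p = P.rsample()
B = RelaxedBernoulli(temperature, probs=p)  # temperature is a required argument: a positive tensor you choose
b = B.rsample()
loglik = P.log_prob(p) + B.log_prob(b)  # for proper probabilistic models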


I am facing this problem:
RuntimeError: Expected p_in >= 0 && p_in <= 1 to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)

for the following code:
import torch

class SigmoidForwardLinearBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ins):
        # return torch.nn.functional.sigmoid(ins)
        return torch.bernoulli(ins)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

MyFunction = SigmoidForwardLinearBackward.apply

sigmoids = []
bernoulli = []
backwards = []
for x in torch.arange(-5, 5, 0.01):
    x.requires_grad = True
    outs = MyFunction(x)
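
The RuntimeError itself most likely comes from the inputs rather than from autograd: torch.bernoulli expects probabilities in [0, 1], while the loop feeds it raw values from -5 to 5. A minimal sketch, assuming the commented-out sigmoid was meant to be applied before sampling (the variable names here are only illustrative):

```python
import torch

torch.manual_seed(0)

# Raw values like the ones in the loop above; most lie outside [0, 1].
x = torch.arange(-5, 5, 0.01)

# Squash the raw values into (0, 1) so they are valid probabilities.
probs = torch.sigmoid(x)

# Sampling now succeeds; every entry of `samples` is 0.0 or 1.0.
samples = torch.bernoulli(probs)
```

This only removes the range error. It does not make the sampling step differentiable.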

Thank you so much for the information. Could you please explain it a bit more clearly?

  1. The derivative of torch.bernoulli is zero (because its output is discrete).
  2. Your autograd.Function doesn’t make sense, because the derivative of your “function” doesn’t depend on its input. You don’t handle the problem of discreteness, and you could just as well supply any “external” binary tensor to the network with the same effect (no backprop beyond it).
  3. RelaxedBernoulli is a workaround that works by allowing continuous outputs (samples) during training. It is not an easy approach, but you can read about it in [1611.00712] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables.
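
To make points 1 and 3 concrete, here is a rough sketch that compares the gradient reaching some trainable logits through a hard torch.bernoulli sample with the gradient through a RelaxedBernoulli rsample. The `logits` tensor and the temperature of 0.5 are assumptions chosen only for illustration.

```python
import torch
from torch.distributions import RelaxedBernoulli

torch.manual_seed(0)

# Hypothetical trainable weights (as logits), not taken from the thread.
logits = torch.zeros(4, requires_grad=True)

# Hard sampling: depending on the PyTorch version, the sample either has no
# gradient path to `logits` at all, or the gradient is identically zero.
hard = torch.bernoulli(torch.sigmoid(logits))
hard_grad = None
if hard.requires_grad:
    (hard_grad,) = torch.autograd.grad(hard.sum(), logits, allow_unused=True)

# Relaxed sampling: the sample is a continuous value between 0 and 1 and is
# differentiable with respect to `logits` (reparameterized sample).
temperature = torch.tensor(0.5)  # arbitrary choice for this sketch
relaxed = RelaxedBernoulli(temperature, logits=logits).rsample()
relaxed.sum().backward()
relaxed_grad = logits.grad
```

Lower temperatures push the relaxed samples closer to 0 or 1, at the cost of noisier gradients.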

Another alternative, especially if you can’t use continuous 0…1 outputs, is to use a “score function”, similar to how it is described in the docs.
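
A minimal sketch of the score function approach with hard 0/1 samples follows. The `reward_fn` and the `logits` parameter are hypothetical stand-ins for whatever the downstream objective and weights would be. The idea is that the samples stay discrete and the gradient reaches the weights through log_prob instead of through the samples.

```python
import torch
from torch.distributions import Bernoulli

torch.manual_seed(0)

# Hypothetical trainable weights (as logits), not taken from the thread.
logits = torch.zeros(3, requires_grad=True)

def reward_fn(b):
    # Hypothetical non-differentiable reward: number of ones in each sample.
    return b.sum(dim=-1)

dist = Bernoulli(logits=logits)
b = dist.sample((1024,))       # hard 0/1 samples; no gradient flows through them
reward = reward_fn(b)          # shape (1024,)

# Surrogate loss: its gradient is a Monte Carlo (score function) estimate of
# the gradient of the negative expected reward with respect to `logits`.
log_prob = dist.log_prob(b).sum(dim=-1)
loss = -(log_prob * reward.detach()).mean()
loss.backward()
```

After backward, `logits.grad` holds a noisy gradient estimate. In practice a baseline is often subtracted from the reward to reduce its variance.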