Question about how the -m.log_prob() function in torch.distributions.bernoulli works?


I am trying to figure out what the -m.log_prob() function is actually doing when implementing policy gradients.

I have stepped through the m.log_prob() function in the debugger many times and still fail to see how it works. My understanding is that the call chain is:

m.log_prob() >> which calls logits() >> which then calls cross_entropy_with_logits()

This is all fine, except that I cannot recreate the value produced in the logits function.

I get that logits() is just the log odds, i.e. log(p / (1 − p)), but when I plug the same value of p (generated by the net on line 1 below, probs = p) into that equation by hand, I fail to recreate the value produced by PyTorch's logits function!
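One possible source of the mismatch, offered as a sketch rather than a confirmed diagnosis: the log-odds formula applies to Bernoulli, while the snippet further down uses Categorical, whose logits are the log of the normalized probabilities, log(p_i), not log(p / (1 − p)). The values 0.7 and [0.7, 0.3] below are made up for illustration and stand in for the network output.

```python
import math
import torch
from torch.distributions import Bernoulli, Categorical

p = 0.7  # hypothetical probability, standing in for the network output

# Bernoulli: logits are the log odds, log(p / (1 - p)).
bern = Bernoulli(probs=torch.tensor(p))
log_odds_by_hand = math.log(p / (1 - p))
bern_logit = bern.logits.item()

# Categorical: logits are the log of the normalized probabilities, log(p_i).
probs = torch.tensor([0.7, 0.3])  # hypothetical two-action distribution
cat = Categorical(probs=probs)
cat_logits = cat.logits
log_probs_by_hand = torch.log(probs / probs.sum())

print(bern_logit, log_odds_by_hand)
print(cat_logits, log_probs_by_hand)
```

If the hand computation uses the log-odds formula on probabilities that were fed to Categorical, the two numbers would not be expected to agree.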

The reason I am asking is that I think this function is at the core of what is going on, and because PyTorch is so efficient, it obscures "the magic" of what is actually happening under the hood in vanilla policy gradients (aka REINFORCE).

Specifically, I am referring to what is going on inside the -m.log_prob() function here:

  1. probs = policy_network(state)
  2. # Note that this is equivalent to what used to be called multinomial
  3. m = Categorical(probs)
  4. action = m.sample()
  5. next_state, reward = env.step(action)
  6. loss = -m.log_prob(action) * reward
  7. loss.backward()
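To make the steps above concrete, here is a minimal self-contained sketch. It assumes a softmax policy; the `scores` tensor and the constant `reward` are hypothetical stand-ins for `policy_network(state)` and `env.step(action)`, which are not defined in this post. It checks two things: that `m.log_prob(action)` is the log of the probability assigned to the sampled action, and that the resulting gradient with respect to the pre-softmax scores is `reward * (probs - one_hot(action))`.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

torch.manual_seed(0)

# Hypothetical stand-ins: `scores` plays the role of the policy network's
# pre-softmax output, and `reward` is a made-up constant.
scores = torch.tensor([0.5, -1.0, 2.0], requires_grad=True)
reward = 3.0

probs = torch.softmax(scores, dim=-1)
m = Categorical(probs)
action = m.sample()

loss = -m.log_prob(action) * reward
loss.backward()

# log_prob is the log of the probability of the sampled action.
manual_loss = -torch.log(probs.detach()[action]) * reward

# For a softmax policy: d(loss)/d(scores) = reward * (probs - one_hot(action)).
one_hot = F.one_hot(action, num_classes=3).float()
manual_grad = reward * (probs.detach() - one_hot)

print(loss.item(), manual_loss.item())
print(scores.grad, manual_grad)
```

Read this way, a positive reward pushes the score of the sampled action up and the scores of the other actions down, in proportion to their current probabilities.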

This example is straight out of the PyTorch documentation here:

Just checking in here - does it not just call .gather on the logits?
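A quick way to check this, as a sketch under the assumption that `m` is a Categorical (the batch of probabilities and the action indices below are made up): gather the entry of `m.logits` at each action index and compare it with `m.log_prob`.

```python
import torch
from torch.distributions import Categorical

# Hypothetical batch of two action distributions over three actions.
probs = torch.tensor([[0.2, 0.5, 0.3],
                      [0.6, 0.1, 0.3]])
actions = torch.tensor([1, 2])

m = Categorical(probs)

# m.logits holds the normalized log-probabilities, shape (2, 3).
# gather picks, for each row, the entry at that row's action index.
gathered = m.logits.gather(-1, actions.unsqueeze(-1)).squeeze(-1)

print(gathered)
print(m.log_prob(actions))
```

For Categorical the two printed tensors should match, so `log_prob` amounts to an index lookup into the log-probabilities; for Bernoulli the computation goes through a binary cross-entropy on the logits instead.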