Hi,
I am trying to figure out what the -m.log_prob() function is actually doing when implementing policy gradients.
I have stepped through the m.log_prob() function in the debugger many times and still fail to see how it works. My understanding is that the call chain is:
m.log_prob() >> which calls logits() >> which then calls cross_entropy_with_logits()
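To see what that chain ultimately returns, I tried a toy probs vector in place of the network output (the vector below is mine, not from the docs):

```python
import torch
from torch.distributions import Categorical

# Toy probability vector standing in for the policy network's output.
probs = torch.tensor([0.1, 0.2, 0.7])
m = Categorical(probs)

# What the debugger shows log_prob() returning for each possible action:
for a in range(3):
    print(a, m.log_prob(torch.tensor(a)).item())
```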
This is all fine, except that I cannot recreate the value produced inside the logits function.
I get that logits() is just the log odds, i.e. log(p / (1 − p)), but when I plug the same value of p (generated by the net on line 1 below, probs = p) into that equation by hand, I fail to recreate the value produced by PyTorch's logits function!
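Here is a minimal reproduction of the mismatch I am seeing, with a toy probs vector standing in for the network output (the specific numbers are mine, not from the docs):

```python
import torch
from torch.distributions import Categorical

# Toy probability vector standing in for the policy network's output.
probs = torch.tensor([0.1, 0.2, 0.7])

# Log odds computed by hand: log(p / (1 - p)).
log_odds = torch.log(probs / (1 - probs))

# What Categorical exposes as .logits.
m = Categorical(probs)

print(log_odds)  # hand-computed log odds
print(m.logits)  # the values PyTorch actually produces -- they differ
```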
The reason I am asking is that I think this function is at the core of what is going on, and because PyTorch is so efficient, it obscures "the magic" of what is actually happening under the hood in vanilla policy gradients (aka REINFORCE).
Specifically, I am referring to what is going on inside the -m.log_prob() function here:
```python
probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()
```
This example is straight out of the PyTorch documentation here: https://pytorch.org/docs/stable/distributions.html?highlight=reinforce
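For reference, here is a self-contained version of that snippet that runs on its own; the tiny policy network and the one-step dummy environment are stand-ins I added, not part of the original example:

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

torch.manual_seed(0)

# Stand-in policy network: 4-dim state -> probabilities over 2 actions.
policy_network = nn.Sequential(nn.Linear(4, 2), nn.Softmax(dim=-1))

# Stand-in for env.step(): returns a dummy next state and a fixed reward.
def env_step(action):
    return torch.zeros(4), 1.0

state = torch.randn(4)
probs = policy_network(state)
m = Categorical(probs)
action = m.sample()
next_state, reward = env_step(action)
loss = -m.log_prob(action) * reward
loss.backward()

# Sanity check: -m.log_prob(action) is just -log(probs[action]).
print(loss.item(), -torch.log(probs[action]).item() * reward)
```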