What does action.reinforce(r) actually do?

Hi,

I am studying RL with reinforcement_learning/reinforce.py in pytorch/examples. I have some questions about it.

  1. What does action.reinforce(r) do internally?

  2. The REINFORCE update rule, where v_t is the return, is \theta \leftarrow \theta + \alpha \nabla_\theta \log \pi_\theta(a_t | s_t) v_t. We need to do gradient “ascent” as written, but optimizer.step() performs gradient “descent”. Is [action.reinforce(r)](https://github.com/pytorch/examples/blob/main/reinforcement_learning/reinforce.py) multiplying the log probability by -r? Then it makes sense.

  3. In autograd.backward(model.saved_actions, [None for _ in model.saved_actions]), what is the role of None here?

Thanks.

  1. It finds the .creator of the output and forwards the call to it. Basically, it just saves the reward in the .reward attribute of the creator function. Then, when the backward method is called, the StochasticFunction class discards the grad_output it received and passes the saved reward to the backward method.
  2. Yes, the gradient formulas are written in such a way that they negate the reward. You might not find reward.neg() there because they might have been slightly rewritten, but it’s still a gradient meant to be used with gradient descent.
  3. You need to give autograd the first list, so that it can discover all the stochastic nodes you want to optimize. The second list can look like that because the stochastic functions don’t need any gradients (they’ll discard them anyway), so you can give them None.
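Putting the three answers together, here is a minimal sketch of the pattern, assuming the old (pre-0.4) stochastic-function API that the thread and the example use; probs stands in for the policy network's output:

import torch
from torch import autograd
from torch.autograd import Variable

probs = Variable(torch.Tensor([0.3, 0.7]), requires_grad=True)
action = probs.multinomial()             # stochastic node; keeps a .creator
action.reinforce(torch.Tensor([1.0]))    # saves the reward on the creator
# the stochastic function discards grad_output, so a None placeholder is enough
autograd.backward([action], [None])
print(probs.grad)                        # -reward / prob at the sampled index, 0 elsewhere

With several saved actions this becomes the autograd.backward(model.saved_actions, [None for _ in model.saved_actions]) call from the example.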

@apaszke Thanks very much!

If I understand correctly, for Vanilla Policy Gradient, which updates only once per trajectory rather than at every step as in REINFORCE, we sum the log-probability of the policy multiplied by the advantage estimate over all time steps, compute a single objective as a Variable, and call .backward() on it to update the parameters?

No, you just somehow compute the total trajectory reward (probably by decaying the older rewards) and pass that to .reward. That’s all you have to do.
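For reference, one common way to do that decay (a sketch with assumed names, roughly what the linked example computes per step):

# discounted returns, newest reward last; gamma is the decay factor
def discounted_returns(rewards, gamma=0.99):
    R, returns = 0.0, []
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)
    return returns

print(discounted_returns([0.0, 0.0, 1.0]))   # [0.9801, 0.99, 1.0]
# each saved stochastic action then gets its return via action.reinforce(...)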


In REINFORCE, to reduce the variance one usually averages the loss over a number of sampled trajectories (batch_size). Do I understand correctly that in such a case I need to call action.reinforce(reward / batch_size) for every action in each trajectory?

action and reward should both be tensors with batch_size as their first dimension

This is understandable but it does not explicitly answer my question: does actions.reinforce(rewards) correspond to the average or sum of errors (over the batch)?

As a side note: I really like how the code simplifies with reinforce but even after reading all available information about it I still feel uncomfortable using it.


I think I see what you’re asking now. The gradient estimates from all the examples in the batch are added together in the stochastic node (there’s no implicit division by the batch size) so, depending on your use case, you may want to manually divide your rewards by the batch size.
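A small sketch of that (old pre-0.4 API; the shapes here are assumptions): dividing the rewards by batch_size turns the implicit sum over the batch into an average.

import torch
from torch import autograd
from torch.autograd import Variable

batch_size = 4
probs = Variable(torch.Tensor(batch_size, 3).fill_(1.0 / 3), requires_grad=True)
actions = probs.multinomial()                  # (batch_size, 1) sampled actions
rewards = torch.ones(batch_size, 1)            # one reward per example
actions.reinforce(rewards / batch_size)        # average instead of sum over the batch
autograd.backward([actions], [None])
print(probs.grad)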

  1. Instead of replacing the grad_output, shouldn’t you scale grad_output by the reward?
  2. The output is a softmax https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py#L43 which I understand, but there is no cross-entropy loss in this case. Am I missing something?

In the naive REINFORCE method (which is used in the example), we use \nabla_\theta \log \pi_\theta \cdot v(t) for the update. Just forget about cross-entropy loss here. PyTorch provides the reinforce() method on Variable to bind the corresponding v(t) from the formula. You can try the following code to check:

import torch
from torch.autograd import Variable

x = Variable(torch.Tensor([0.1, 0.9]), requires_grad=True)
y = torch.multinomial(x, 1)    # here, make sure y.data[0] = 1 (the draw picked index 1)
r = torch.Tensor([2]).float()  # the return v(t) to bind to the stochastic node
y.reinforce(r)
y.backward()
print(x.grad.data)

This will give you the output [0.0, -2.2222], and -2.2222 is exactly 1.0/0.9 * -v(t) (because d(log p)/dp = 1/p).
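That number is easy to check by hand without the stochastic machinery; this small snippet just differentiates -v(t) * log p at the sampled index (index 1 is assumed to be the draw):

import torch
from torch.autograd import Variable

x = Variable(torch.Tensor([0.1, 0.9]), requires_grad=True)
r = 2.0                          # v(t)
loss = -r * torch.log(x[1])      # pretend the multinomial draw picked index 1
loss.backward()
print(x.grad.data)               # [0.0, -2.2222], i.e. -v(t) / p at the sampled index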


I implemented Vanilla Policy Gradient in this way and it works. However, using the .reinforce method seems cleaner.
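For comparison, a sketch of that explicit-loss variant (the tensors here are stand-ins; in a real agent the log-probabilities come from the policy network and the advantages from the rollout):

import torch
from torch.autograd import Variable

log_probs = [Variable(torch.Tensor([-0.5]), requires_grad=True),
             Variable(torch.Tensor([-1.2]), requires_grad=True)]
advantages = [1.0, -0.5]
# one differentiable objective for the whole trajectory, then a single backward
loss = sum(-lp * adv for lp, adv in zip(log_probs, advantages))
loss.backward()
print([lp.grad.data for lp in log_probs])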

Hi @apaszke

In this file, how should we understand the backward pass for the Normal distribution?

To make it simple, let’s say 1D, given a mean and std, we have a sample = mean + std*eps, where eps ~ N(0, 1).

In the backward, the grad_mean = -reward*(sample - mean)/std**2.

It is not very clear to me why that is. If we have a sample, then according to the formula above, d_sample/d_mean = 1, so I would expect grad_mean = upstream_gradient * d_sample/d_mean.

@zuoxingdong I recommend you read the paper linked at the top of the file you linked. It goes through all the derivations. The case of the normal distribution is derived in section 6, the formula you are asking about specifically is (13).
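For reference, the derivation is short. The score-function estimator differentiates the log-density with respect to the parameters rather than with respect to the sample:

\frac{\partial}{\partial \mu} \log \mathcal{N}(x; \mu, \sigma) = \frac{\partial}{\partial \mu}\left[ -\frac{(x - \mu)^2}{2\sigma^2} - \log(\sigma\sqrt{2\pi}) \right] = \frac{x - \mu}{\sigma^2}

Negating it for a descent step gives exactly grad_mean = -reward*(sample - mean)/std**2. The d_sample/d_mean = 1 from the reparameterization sample = mean + std*eps never appears, because the gradient is taken through the density, not through the sample.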


@abhigenie92
Agreed. I am also not clear on this point. Do you know how to achieve that? Is that understanding correct?

I don't see the paper you mentioned; could you give a PDF or a link? Thanks.

Indeed, it looks like the link was removed in the last commit, which is unfortunate. The paper was Williams '92, "Simple Statistical Gradient-Following Algorithms for Connectionist Reinforcement Learning", a.k.a. "the REINFORCE paper".


Just in case it’s useful, I can reproduce the results from calling .reinforce with the following code:

"""
Try some reinforce by hand and similar
"""

import torch
from torch import autograd, nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np


def run_model(x, h1):
    torch.manual_seed(123)
    params = list(h1.parameters())

    x1 = h1(Variable(x))
    x2 = F.softmax(x1)
    a = torch.multinomial(x2)
    print('a', a)
    return x1, x2, a


def run_by_hand(params, x, h1):
    print('')
    print('=======')
    print('by hand')
    h1.zero_grad()
    x1, x2, a = run_model(x, h1)
    # probability of each sampled action
    g = torch.gather(x2, 1, Variable(a.data))
    log_g = g.log()
    # gradient of -log pi(a), with a reward of 1 for every sample
    log_g.backward(- torch.ones(4, 1))
    print('params.grad', params.grad)


def run_pytorch_reinforce(params, x, h1):
    print('')
    print('=======')
    print('pytorch reinforce')
    h1.zero_grad()
    x1, x2, a = run_model(x, h1)
    # same thing via the stochastic node: reward of 1 for every sample
    a.reinforce(torch.ones(4, 1))
    autograd.backward([a], [None])
    print('params.grad', params.grad)


def run():
    N = 4
    K = 1
    C = 3

    torch.manual_seed(123)
    x = torch.ones(N, K)
    h1 = nn.Linear(K, C, bias=False)
    params = list(h1.parameters())[0]
    run_by_hand(params, x, h1)
    run_pytorch_reinforce(params, x, h1)


if __name__ == '__main__':
    run()

Result:

=======
by hand
a Variable containing:
 1
 1
 0
 1
[torch.LongTensor of size 4x1]

params.grad Variable containing:
 0.6170
-1.3288
 0.7117
[torch.FloatTensor of size 3x1]


=======
pytorch reinforce
a Variable containing:
 1
 1
 0
 1
[torch.LongTensor of size 4x1]

params.grad Variable containing:
 0.6170
-1.3288
 0.7117
[torch.FloatTensor of size 3x1]

(It's actually almost the same amount of code, in fact?)