'Normal' object has no attribute 'rsample'

Trying to implement the pathwise derivative for a stochastic policy as mentioned here. From the documentation:

Another way to implement these stochastic/policy gradients would be to use the reparameterization trick from rsample() method, where the parameterized random variable can be defined as a parameterized deterministic function of a parameter-free random variable. The reparameterized sample is required to be differentiable. The code for implementing the pathwise estimation would be as follows:

params = policy_network(state)
m = Normal(*params)
# any distribution with .has_rsample == True could work based on the application
action = m.rsample()
next_state, reward = env.step(action)  # Assume that reward is differentiable
loss = -reward
loss.backward()
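A real environment's env.step() is usually not differentiable. Below is a minimal self-contained sketch of the same pathwise estimator that uses a made-up differentiable reward instead. toy_reward, target, and the learned loc/log_std tensors are hypothetical and do not come from the docs. Gradients reach the mean and the log-std only because rsample() expresses the action as loc + std * eps, where eps is parameter-free noise.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Hypothetical policy parameters: a mean and a log-std to be learned directly.
loc = torch.zeros(2, requires_grad=True)
log_std = torch.zeros(2, requires_grad=True)

# Hypothetical differentiable "reward": largest when the action is near target.
target = torch.tensor([1.0, -1.0])

def toy_reward(action):
    return -((action - target) ** 2).sum()

optimizer = torch.optim.SGD([loc, log_std], lr=0.05)
for step in range(500):
    m = Normal(loc, log_std.exp())  # exp() keeps the scale strictly positive
    assert m.has_rsample            # Normal supports reparameterized sampling
    action = m.rsample()            # loc + std * eps, differentiable in loc and log_std
    loss = -toy_reward(action)      # maximize reward by minimizing its negative
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(loc.detach())  # should end up close to target
```

The same pattern carries over to a policy network: replace the two leaf tensors with the network's outputs and the optimizer with one over the network's parameters.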

Here I'm assuming that params holds the mean actions of a normal distribution over each action (some clarification on this would be good). However, when I implement this, I get the error given in the title. My action selection function is:
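In the documentation snippet, Normal(*params) unpacks params into the first two constructor arguments of Normal, loc (the mean) and scale (the standard deviation). That suggests policy_network is expected to return both values, not only the mean. The sketch below shows one common way to do that, though the docs don't specify it. GaussianPolicy, its layer sizes, and the softplus-based scale head are all hypothetical choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

class GaussianPolicy(nn.Module):
    """Hypothetical policy: maps a state to the (mean, std) of a diagonal Normal."""

    def __init__(self, state_dim, action_dim, hidden=64):
        super().__init__()
        self.body = nn.Linear(state_dim, hidden)
        self.mean_head = nn.Linear(hidden, action_dim)
        self.std_head = nn.Linear(hidden, action_dim)

    def forward(self, state):
        h = torch.tanh(self.body(state))
        mean = self.mean_head(h)
        std = F.softplus(self.std_head(h)) + 1e-5  # scale must be strictly positive
        return mean, std

policy_network = GaussianPolicy(state_dim=4, action_dim=2)
state = torch.randn(1, 4)
params = policy_network(state)  # a (mean, std) tuple
m = Normal(*params)             # equivalent to Normal(loc=mean, scale=std)
action = m.rsample()            # still attached to the network's graph
print(action.shape)             # torch.Size([1, 2])
```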

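import torch
from torch.autograd import Variable  # from PyTorch 0.4 on, Variable(t) just returns t

def select_action(state):
    state = torch.from_numpy(state).float().unsqueeze(0)
    mu, state_value = model(Variable(state))
    # Note: Normal's second argument is the standard deviation (scale);
    # env.action_space.shape[0] is the number of action dimensions, not a std.
    m = torch.distributions.Normal(mu, env.action_space.shape[0])
    action = m.rsample()  # raises the AttributeError in the title if Normal has no rsample()
    model.saved_actions.append(SavedAction(m.log_prob(action), state_value))
    return action.data[0]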
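For a Normal, what rsample() computes can also be written by hand, which makes it clearer why gradients flow through the sample. The trick is to draw parameter-free noise eps ~ N(0, 1) and return mu + std * eps. In the sketch below, reparameterized_normal_sample is a hypothetical helper and mu/std are made-up tensors standing in for the model's outputs. The sketch assumes PyTorch 0.4 or later, since it uses torch.randn_like. It checks that the sample's gradient is 1 with respect to mu and eps with respect to std.

```python
import torch

def reparameterized_normal_sample(mu, std):
    """Hypothetical helper: a Normal(mu, std) sample as a differentiable function of mu and std."""
    eps = torch.randn_like(mu)  # parameter-free noise, eps ~ N(0, 1)
    return mu + std * eps, eps

mu = torch.tensor([0.5, -0.2], requires_grad=True)
std = torch.tensor([1.0, 0.3], requires_grad=True)
action, eps = reparameterized_normal_sample(mu, std)
action.sum().backward()
print(mu.grad)                     # tensor([1., 1.]): d action / d mu = 1
print(torch.equal(std.grad, eps))  # True: d action / d std = eps
```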

Did I make a mistake here? Is there a working example of the pathwise derivative available for learning?

Cheers


Maybe you were using PyTorch 0.3? The .rsample() method was only added in PyTorch 0.4, so upgrading should make this error go away.
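A quick way to confirm which case applies is to print torch.__version__ and check the distribution's has_rsample flag before calling rsample(). This is a minimal sketch. The getattr default assumes that older releases may not define has_rsample at all, and the error message is illustrative only.

```python
import torch
from torch.distributions import Normal

print(torch.__version__)

m = Normal(torch.zeros(1), torch.ones(1))
# has_rsample reports whether reparameterized sampling is supported;
# getattr falls back to False if the attribute does not exist.
if getattr(m, "has_rsample", False):
    action = m.rsample()
else:
    raise RuntimeError("Normal.rsample() is unavailable here; PyTorch 0.4 or later is needed.")
```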