I am trying to train the following policy gradient network, but the weights do not update, and I am not sure what the problem is. Can someone help, please?
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Policy(nn.Module):
    def __init__(self, obs_shape, hidden_size=64):
        super(Policy, self).__init__()
        # `init` is an external helper (not shown) that applies the given
        # weight/bias initialisers; note that init_ is never actually applied below
        init_ = lambda m: init(m, nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0), np.sqrt(2))
        self.actor = nn.Sequential(
            nn.Linear(obs_shape, hidden_size), nn.BatchNorm1d(hidden_size), nn.ReLU(),
            nn.Linear(hidden_size, hidden_size), nn.BatchNorm1d(hidden_size), nn.ReLU(),
            nn.Linear(hidden_size, 3))

    def forward(self, inputs):
        x = inputs
        hidden_actor = self.actor(x)
        return hidden_actor

    def act(self, inputs, deterministic=False):
        actor_features = self.forward(inputs)
        action_prob = F.softmax(actor_features, dim=-1)
        dist = distributions.Categorical(action_prob)
        if deterministic:
            # Categorical has no callable mode(); take the argmax instead
            action = dist.probs.argmax(dim=-1)
        else:
            action = dist.sample()
        action_log_probs = dist.log_prob(action)
        dist_entropy = dist.entropy().mean()
        return action, action_log_probs

    def evaluate_actions(self, inputs, action):
        actor_features = self.forward(inputs)
        action_prob = F.softmax(actor_features, dim=-1)
        dist = distributions.Categorical(action_prob)
        action_log_probs = dist.log_prob(action)
        return action_log_probs
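For completeness, the policy and optimizer are constructed roughly like this (a sketch; the observation size and learning rate shown here are placeholders, not my exact values):

policy = Policy(obs_shape=4).to(device)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)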
The train function is the following:
def train(env, policy, optimizer, discount_factor):
    policy.train()
    log_prob_actions = []
    rewards = []
    states = []
    actions = []
    done = False
    episode_reward = 0
    state = env.reset()
    while not done:
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        states.append(state)
        policy.eval()
        with torch.no_grad():
            action, log_prob_action = policy.act(state)
        policy.train()
        actions.append(action)
        state, reward, done, _ = env.step(action.item())
        log_prob_actions.append(log_prob_action)
        rewards.append(reward)
        episode_reward += reward
    log_prob_actions = torch.cat(log_prob_actions)
    returns = calculate_returns(rewards, discount_factor).to(device)
    loss, grad_norm = update_policy_version2(returns, states, actions, optimizer)
    return loss, grad_norm, episode_reward, states, actions, returns, rewards
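calculate_returns is the usual discounted-return helper; for reference, it looks roughly like this (sketch):

def calculate_returns(rewards, discount_factor, normalize=True):
    # accumulate discounted rewards from the end of the episode backwards
    returns = []
    R = 0
    for r in reversed(rewards):
        R = r + discount_factor * R
        returns.insert(0, R)
    returns = torch.tensor(returns, dtype=torch.float32)
    if normalize:
        # standardise returns to zero mean / unit variance
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    return returns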
The update-policy function is:
def update_policy_version2(returns, states, actions, optimizer):
    # note: `policy` here is the global model (the same object passed to train)
    states = torch.cat(states)
    actions = torch.stack(actions).squeeze()
    log_prob_actions = policy.evaluate_actions(states, actions)
    returns = returns.detach()
    loss = -(returns * log_prob_actions).sum()
    optimizer.zero_grad()
    grad_loss = torch.autograd.grad(loss, policy.parameters(), retain_graph=True)
    grad_loss = torch.cat([g.contiguous().view(-1) for g in grad_loss])
    norm = torch.norm(grad_loss)
    optimizer.step()
    params = []
    for param in policy.parameters():
        params.append(param.view(-1))
    for p in policy.parameters():
        print(p.grad)
    return loss.item(), norm.item()
Printing p.grad yields None for every parameter, while grad_loss computed via torch.autograd.grad contains real values. Calling step() does not update the weights.
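To confirm that the weights really do not move, I compare parameter snapshots around one update; a minimal check (sketch, reusing the lists returned by train) looks like this:

# snapshot parameters, run one update, compare
before = [p.detach().clone() for p in policy.parameters()]
update_policy_version2(returns, states, actions, optimizer)
after = [p.detach().clone() for p in policy.parameters()]
print(any((b != a).any().item() for b, a in zip(before, after)))  # prints False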