Hello everyone. I am using the DDPG implementation from the ghliu/pytorch-ddpg GitHub repository (a PyTorch implementation of Deep Deterministic Policy Gradient) and adapting it to my purposes. However, I am running into a problem with backpropagation of the actor's loss: the gradients of the actor's parameters come back as None. Below is a code snippet of the optimize function together with the helper code it uses.
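To keep the snippet self-contained, these are the imports it relies on (just a sketch of the relevant ones, not my full file):

    import numpy as np
    import torch
    import torch.nn as nn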
    def optimize(self):
        if self.rm.len < (self.size_buffer):
            return

        self.state_encoder.eval()
        # sample a batch from the replay memory
        state, idx, action, set_actions, reward, next_state, curr_perf, curr_acc, done = self.rm.sample(self.batch_size)
        state = torch.from_numpy(state)
        next_state = torch.from_numpy(next_state)
        set_actions = torch.from_numpy(set_actions)
        action = torch.from_numpy(action)
        reward = [r[-1] for r in reward]
        reward = np.expand_dims(np.array(reward), axis = 1)
        reward = torch.from_numpy(np.array(reward))
        reward = reward.cuda()
        done = np.expand_dims(done, axis = 1)
        terminal = torch.from_numpy(done)
        terminal = terminal.cuda()

        # ------- optimize critic ----- #
        state = state.cuda()
        next_state = next_state.cuda()
        a_pred = self.target_actor(next_state)
        pred_perf = self.train_actions(set_actions, a_pred.data, idx, terminal)
        pred_perf = torch.from_numpy(pred_perf)
        # build the successor set-states by re-encoding the predicted action
        new_set_states = torch.Tensor()
        for idx_s, single_state in enumerate(next_state):
            new_state = single_state
            if done[idx_s]:
                next_indx = int(idx[idx_s])
            else:
                if idx[idx_s] < 5:
                    next_indx = int(idx[idx_s] + 1)
                else:
                    next_indx = int(idx[idx_s])
            new_state[next_indx, :] = self.state_encoder(a_pred[idx_s].data.cpu().float(), pred_perf[idx_s].cpu().float())
            new_state = new_state[None, :]
            new_set_states = torch.cat((new_set_states, new_state.cpu()), dim = 0)
        new_set_states = torch.from_numpy(np.array(new_set_states))
        new_set_states = new_set_states.cuda()
        # target: reward + (1 - done) * Q_target(new set-state)
        target_values = torch.add(reward, torch.mul(~terminal, self.target_critic(new_set_states)))
        val_expected = self.critic(next_state)
        criterion = nn.MSELoss()
        loss_critic = criterion(target_values, val_expected)
        print(val_expected, target_values, loss_critic)
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        self.critic_optimizer.step()
        for name, param in self.critic.named_parameters():
            print('here', name, param.grad, param.requires_grad, param.is_leaf)

        # ----- optimize actor ----- #
        pred_a1 = self.actor(state)
        # pred_perf = self.train_actions(idx, pred_a1, curr_arch, inputs)
        pred_perf = self.train_actions(set_actions, pred_a1.data, idx, terminal)
        pred_perf = torch.from_numpy(pred_perf)
        new_set_states = torch.Tensor()
        for idx_s, single_state in enumerate(state):
            new_state = single_state
            if done[idx_s]:
                next_indx = int(idx[idx_s])
            else:
                if idx[idx_s] < 5:
                    next_indx = int(idx[idx_s] + 1)
                else:
                    next_indx = int(idx[idx_s])
            new_state[next_indx, :] = self.state_encoder(pred_a1[idx_s].data.cpu().float(), pred_perf[idx_s].cpu().float())
            new_state = new_state[None, :]
            new_set_states = torch.cat((new_set_states, new_state.cpu()), dim = 0)
        new_set_states = torch.from_numpy(np.array(new_set_states))
        new_set_states = new_set_states.cuda()
        # self.actor_optimizer.zero_grad()
        # print(a_pred, pred_perf)
        # print(self.critic(new_set_states), -1*self.critic(new_set_states), (-1*self.critic(new_set_states)).mean())
        # loss_actor = (-1*self.critic(new_set_states)).mean()
        loss_fn = CustomLoss(self.actor, self.critic)
        loss_actor = loss_fn(new_set_states)
        # print('loss_actor', loss_actor)
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        self.actor_optimizer.step()
        for name, param in self.actor.named_parameters():
            print('here', name, param.grad, param.requires_grad, param.is_leaf)
        # self.actor_optimizer.step()
        # self.losses['actor_loss'].append(loss_actor.item())
        self.losses['critic_loss'].append(loss_critic.item())

        # soft update of the target networks
        TAU = 0.001
        self.utils.soft_update(self.target_actor, self.actor, TAU)
        self.utils.soft_update(self.target_critic, self.critic, TAU)
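For reference, self.utils.soft_update comes from the repo's util module; as far as I understand it just performs the usual Polyak averaging of the target parameters, roughly like this sketch:

    def soft_update(target, source, tau):
        # target <- tau * source + (1 - tau) * target, parameter by parameter
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)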
The code for the CustomLoss is as follows:
    class CustomLoss(nn.Module):
        def __init__(self, actor, critic):
            super(CustomLoss, self).__init__()
            self.actor = actor
            self.critic = critic

        def forward(self, state):
            # negate the critic's value so that minimizing the loss maximizes Q
            loss = torch.mul(-1, self.critic(state))
            loss = loss.mean()
            return loss
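For comparison, my understanding of the standard DDPG actor update is that the actor's output stays inside the autograd graph when it is fed to the critic, roughly like the sketch below. The critic signature taking (state, action) is an assumption here, since my adapted critic only takes the set-state:

    # minimal sketch of the usual DDPG actor update (not my actual code)
    pred_action = self.actor(state)                        # kept in the graph (no .data)
    loss_actor = -self.critic(state, pred_action).mean()   # gradient flows critic -> actor
    self.actor_optimizer.zero_grad()
    loss_actor.backward()
    self.actor_optimizer.step()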
Any help is really appreciated.