Hi @gemsanyou,
I think my code is not correct: I forgot to zero the gradients for each action. This should be correct:

```python
actor_optimizer.zero_grad()
action_Q = actor(state_batch)
policy_output = -critic(state_batch, action_Q)
for i in range(policy_output.size()[0]):
    # one backward pass per sample; retain_graph keeps the graph alive for the next iteration
    policy_output[i].backward(retain_graph=True)
    for name, param in actor.named_parameters():
        params_grads[name] += invert_gradient(param.grad.data, action_Q.detach()[i])
    # zero the gradients before the next sample's backward pass
    # (set_to_none=False keeps .grad as a tensor so it can be overwritten below)
    actor_optimizer.zero_grad(set_to_none=False)
for name, param in actor.named_parameters():
    # average the accumulated inverted gradients over the batch
    param.grad.data = params_grads[name] / batch_size
    params_grads[name] = torch.zeros_like(param.data)
actor_optimizer.step()
```
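For completeness, this snippet assumes `params_grads` and `batch_size` already exist; a minimal setup sketch (the names follow the code above):

```python
import torch

# accumulator for the inverted per-sample gradients, one entry per actor parameter
params_grads = {name: torch.zeros_like(param.data)
                for name, param in actor.named_parameters()}
batch_size = state_batch.size()[0]
```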
Regarding your questions:
- I see no problem applying it to discrete actions as long as you know the action bounds (e.g. relax the discrete action to a continuous value in `[min_b, max_b]` during training and round it when acting).
- If each action has different bounds, you can change the `invert_gradient` function to take the bounds as arguments, as follows:
```python
def invert_gradient(grads, action, max_b, min_b):
    # fraction of the range left before hitting the upper / lower bound
    pdiff_max = torch.div(max_b - action, max_b - min_b)
    pdiff_min = torch.div(action - min_b, max_b - min_b)
    zeros_grad = torch.zeros_like(grads)
    # positive gradients (pushing the action up) are scaled by the distance to the
    # upper bound, negative ones by the distance to the lower bound
    return torch.where(torch.gt(grads, zeros_grad),
                       torch.mul(grads, pdiff_max),
                       torch.mul(grads, pdiff_min))
```
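As a quick sanity check of the scaling (made-up numbers, assuming a 2-dimensional action in `[-1, 1]`):

```python
import torch

grads = torch.tensor([0.5, -0.5])   # pushing action 0 up, action 1 down
action = torch.tensor([0.9, 0.1])   # action 0 is already close to its upper bound
max_b = torch.tensor([1.0, 1.0])
min_b = torch.tensor([-1.0, -1.0])
print(invert_gradient(grads, action, max_b, min_b))
# tensor([ 0.0250, -0.2750]): the positive gradient near the upper bound is
# strongly downscaled; the negative one is scaled by the distance to the lower bound
```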
Then, somewhere outside the training loop you should store the action bounds, for example in a dict called `actions_bounds`. One possible layout (the bounds here are made up; one tensor per side lets `invert_gradient` broadcast over the action dimensions):
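```python
# hypothetical bounds for a 2-dimensional action space
actions_bounds = {
    "max": torch.tensor([1.0, 2.0]),
    "min": torch.tensor([-1.0, 0.0]),
}
```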
The training loop then becomes:
```python
actor_optimizer.zero_grad()
action_Q = actor(state_batch)
policy_output = -critic(state_batch, action_Q)
# the bounds belong to the action dimensions, not to the batch samples,
# so fetch them once, outside the per-sample loop
max_b, min_b = actions_bounds["max"], actions_bounds["min"]
for i in range(policy_output.size()[0]):
    policy_output[i].backward(retain_graph=True)
    for name, param in actor.named_parameters():
        params_grads[name] += invert_gradient(param.grad.data, action_Q.detach()[i], max_b, min_b)
    actor_optimizer.zero_grad(set_to_none=False)
for name, param in actor.named_parameters():
    param.grad.data = params_grads[name] / batch_size
    params_grads[name] = torch.zeros_like(param.data)
actor_optimizer.step()
```
Regarding the final remarks: the inverting-gradients formula downscales the gradient when the action is close to its upper or lower limit, to keep the policy from saturating at the bounds.
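In symbols (this is the inverting-gradients rule from Hausknecht & Stone, which the code above follows; $a$ is one action component and $\nabla_a$ the gradient flowing into it):

$$
\nabla_a \leftarrow \nabla_a \cdot
\begin{cases}
\dfrac{a_{\max} - a}{a_{\max} - a_{\min}} & \text{if } \nabla_a > 0,\\[4pt]
\dfrac{a - a_{\min}}{a_{\max} - a_{\min}} & \text{otherwise,}
\end{cases}
$$

so a positive gradient on an action already sitting at $a_{\max}$ is multiplied by zero, while an action in the middle of its range keeps about half of the gradient magnitude.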