I am trying to implement a basic policy gradient training setup. In the environment I use, not every action is allowed in every state, so I set the probabilities of invalid actions to zero.
The problem is that the policy gradient loss then starts to explode, and the network stops learning until a NaN error occurs.
I tested the same code without zeroing invalid actions, and then it converges and learns the proper actions. So without lines 3-5, using pActions in place of legalActions, everything works fine.
#does not work
pActions, values = self.forward(observations)
legalActions = pActions.clone()
legalActions[~valids] = 0  # valids is a boolean mask; zero out invalid actions
legalActions = legalActions / legalActions.sum(dim=1, keepdim=True)
actions = torch.multinomial(legalActions, 1).long()
logProbs = torch.log(torch.gather(legalActions, -1, actions))
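For reproducibility, here is a self-contained sketch of the failing version above. The toy policy network, batch sizes, and the valids mask are made-up placeholders, not my real setup:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Placeholder stand-in for self.forward(observations):
# batch of 4 observations (dim 8), 5 actions.
policy = nn.Sequential(nn.Linear(8, 5), nn.Softmax(dim=-1))
observations = torch.randn(4, 8)
# Boolean validity mask; every row has at least one valid action.
valids = torch.tensor([[1, 1, 0, 1, 0],
                       [0, 1, 1, 1, 1],
                       [1, 0, 1, 0, 1],
                       [1, 1, 1, 1, 0]], dtype=torch.bool)

pActions = policy(observations)
legalActions = pActions.clone()
legalActions[~valids] = 0                        # zero out invalid actions
legalActions = legalActions / legalActions.sum(dim=1, keepdim=True)
actions = torch.multinomial(legalActions, 1)     # sample only from valid actions
logProbs = torch.log(torch.gather(legalActions, -1, actions))
```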
#works
pActions, values = self.forward(observations)
actions = torch.multinomial(pActions, 1).long()
logProbs = torch.log(torch.gather(pActions, -1, actions))
So what is wrong with my implementation? Does the cloning, or manually setting entries to zero, somehow affect the backward pass in the wrong way? And has anyone ever had a working setup for invalid actions in PyTorch? If so, how did you implement it?
Edit
I found out that it works if I do the normalization on pActions instead of legalActions and use pActions for selecting the logProbs.
#works
pActions, values = self.forward(observations)
legalActions = pActions.clone()
legalActions[~valids] = 0  # valids is a boolean mask; zero out invalid actions
pActions = pActions / legalActions.sum(dim=1, keepdim=True)
actions = torch.multinomial(legalActions, 1).long()
logProbs = torch.log(torch.gather(pActions, -1, actions))
Why is that? Does the cloning somehow prevent the gradient from flowing back to the network properly?
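For reference, one alternative I have seen suggested is to mask the raw logits with -inf before the softmax instead of zeroing probabilities afterwards, so invalid actions get exactly zero probability from the start. A minimal sketch, assuming the policy head returns raw logits (the network, sizes, and mask here are all made up):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical policy head that returns raw logits (no softmax inside).
logits_net = nn.Linear(8, 5)
observations = torch.randn(4, 8)
# Boolean validity mask; every row needs at least one valid action,
# otherwise the masked softmax would produce NaNs.
valids = torch.tensor([[1, 1, 0, 1, 0],
                       [0, 1, 1, 1, 1],
                       [1, 0, 1, 0, 1],
                       [1, 1, 1, 1, 0]], dtype=torch.bool)

logits = logits_net(observations)
masked = logits.masked_fill(~valids, float('-inf'))   # mask before normalizing
legalActions = torch.softmax(masked, dim=-1)          # invalid actions get prob 0
actions = torch.multinomial(legalActions, 1)
# log_softmax on the masked logits is numerically stabler than
# log(softmax(...)) for the sampled (valid) actions.
logProbs = torch.log_softmax(masked, dim=-1).gather(-1, actions)
```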