Setting invalid actions to zero makes the error explode in Policy Gradient

I am trying to implement a basic Policy Gradient training setup. In the environment I use, not every action is allowed in every state, so I set the probabilities of invalid actions to zero.

The problem is that the policy gradient loss then starts to explode, the network stops learning anything, and eventually a NaN error occurs.

I tested the same code without setting invalid actions to zero, and then it is able to converge and learn the proper actions. So without the clone/mask/renormalize lines below, and using pActions in place of legalActions, everything works fine.

#does not work
pActions, values = self.forward(observations)
legalActions = pActions.clone()
legalActions[1 - valids] = 0  # valids is a 0/1 mask of legal actions; zero out the invalid ones
legalActions = legalActions/legalActions.sum(dim=1).unsqueeze(-1).expand(legalActions.size())  # renormalize over the legal actions
actions = torch.multinomial(legalActions, 1).long()
logProbs = torch.log(torch.gather(legalActions, -1, actions))

#works
pActions, values = self.forward(observations)
actions = torch.multinomial(pActions, 1).long()
logProbs = torch.log(torch.gather(pActions, -1, actions))

So what is wrong with my implementation of this? Does the cloning or the manual zeroing somehow affect the backward pass in the wrong way? And has anyone ever had a working setup for invalid actions in PyTorch? If yes, how did you implement it?
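For reference, one alternative setup I have seen masks at the logit level before the softmax, so that all log-probabilities come from log_softmax and no explicit log(0) is ever taken. This is only a sketch, not code from this thread: it assumes the network is changed to return raw pre-softmax logits, that valids is a boolean mask of shape (batch, num_actions), and sample_legal is just a made-up helper name.

import torch
import torch.nn.functional as F

def sample_legal(logits, valids):
    # Push invalid logits to a very negative value so their softmax probability
    # is numerically zero, without ever computing log(0) explicitly.
    maskedLogits = logits.masked_fill(~valids, -1e9)
    logProbsAll = F.log_softmax(maskedLogits, dim=-1)
    actions = torch.multinomial(logProbsAll.exp(), 1)  # only legal actions carry non-zero weight
    logProbs = torch.gather(logProbsAll, -1, actions)  # finite log-probs for the sampled actions
    return actions, logProbs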

Edit
I found out that if I do the normalization on pActions instead of legalActions, and use pActions for selecting the logProbs, it works.

#works
pActions, values = self.forward(observations)
legalActions = pActions.clone()
legalActions[1 - valids] = 0
pActions = pActions/legalActions.sum(dim=1).unsqueeze(-1).expand_as(pActions)  # renormalize the original probs by the legal mass
actions = torch.multinomial(legalActions, 1).long()  # still sample only from the legal actions
logProbs = torch.log(torch.gather(pActions, -1, actions))  # but take the log-prob from the renormalized pActions

Why is that? Does the cloning somehow prevent the gradient from properly flowing back to the network?
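For reference, clone() followed by an in-place masked write does not by itself block the gradient for the entries that are left untouched; only the overwritten positions get a zero gradient. A tiny standalone check (made-up values, unrelated to the model above):

import torch

x = torch.tensor([0.2, 0.3, 0.5], requires_grad=True)
y = x.clone()
y[torch.tensor([False, True, False])] = 0  # overwrite one "invalid" entry in place
y.sum().backward()
print(x.grad)  # tensor([1., 0., 1.]) -> gradient still flows through the untouched entries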

The log of 0 is bound to cause trouble. Also note that your probabilities are unnormalized after dropping states.
You could try legalActions = pActions[valids] instead…

Best regards

Thomas

Hm, but log(0) should actually never happen, since the probabilities of these actions are zero and therefore they are never selected. Or am I wrong?
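As a side note, even when no exact zero is ever selected, the gradient of log(p) is 1/p, so probabilities that get pushed very close to zero already produce huge gradients, and an exact zero gives inf. A tiny standalone illustration (made-up values):

import torch

p = torch.tensor([1e-1, 1e-4, 1e-8, 0.0], requires_grad=True)
torch.log(p).sum().backward()
print(p.grad)  # roughly [1e1, 1e4, 1e8, inf] -- 1/p blows up as p goes to 0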

I do normalize the remaining probabilities by dividing by their sum. Or do you mean something different?