@ptrblck After inspecting the shape of the output of policy_net(state_batch), it appears that the number of outputs is hard-coded to 2.
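A minimal sketch, assuming a simple fully connected Q-network rather than the exact model in this thread: the final layer's width comes from an `n_actions` argument instead of a literal 2, so `policy_net(state_batch)` returns one Q-value per action. The class name `DQN`, the hidden size of 128, and the batch/observation/action sizes below are illustrative assumptions.

```python
import torch
import torch.nn as nn


class DQN(nn.Module):
    """Hypothetical Q-network whose output width is a constructor argument."""

    def __init__(self, n_observations: int, n_actions: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_observations, 128),
            nn.ReLU(),
            # One output per action instead of a hard-coded 2
            nn.Linear(128, n_actions),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


# Illustrative sizes (assumptions, not taken from the original model)
n_observations, n_actions, batch_size = 4, 3, 8
policy_net = DQN(n_observations, n_actions)
state_batch = torch.randn(batch_size, n_observations)

q_values = policy_net(state_batch)
print(q_values.shape)  # torch.Size([8, 3])

# Selecting the Q-value of the action actually taken in each transition;
# action indices must lie in [0, n_actions - 1] for gather to succeed.
action_batch = torch.randint(0, n_actions, (batch_size, 1))
state_action_values = q_values.gather(1, action_batch)
print(state_action_values.shape)  # torch.Size([8, 1])
```

If the output width were fixed at 2 while the environment had more actions, any action index of 2 or higher would make the `gather` call fail with an index error. That is why the last layer is sized from the action space.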