When using PPO-style losses, the TorchRL policy implementations always go through a ProbabilisticActor. At some point it calls a method named .get_dist(), which I assume returns the final distribution layer. Since I am not using one of the available standard distributions, do I need to add a ‘fake’ get_dist, log_prob method, etc. onto my actor network to make it compatible?
Or is that not recommended?
Alternatively, would it make sense to subclass torch.distributions.Distribution and write my own custom distribution that adheres to its interface?
Hey
For PPO we need access to the log-probability of the action under the current parameter configuration, as well as the same quantity under the parameter configuration used at collection time, since the PPO objective is built on the ratio of the two.
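To make that concrete, here is a minimal plain-PyTorch sketch of the clipped PPO objective. It is not TorchRL's ClipPPOLoss implementation; the function name ppo_clip_loss and the clip_eps=0.2 default are illustrative assumptions. It shows that the loss only ever touches the two log-probabilities, so all it needs from a distribution is .log_prob().

```python
import torch
from torch.distributions import Normal


def ppo_clip_loss(log_prob_new, log_prob_old, advantage, clip_eps=0.2):
    """Clipped PPO surrogate loss (negated, so it can be minimized)."""
    # Importance ratio pi_new(a|s) / pi_old(a|s), computed from log-probs
    ratio = torch.exp(log_prob_new - log_prob_old)
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantage
    # PPO maximizes the elementwise minimum; negate the mean to get a loss
    return -torch.min(unclipped, clipped).mean()


# Usage with a built-in distribution: only .log_prob() is needed
torch.manual_seed(0)
old_dist = Normal(torch.zeros(4), torch.ones(4))  # policy at collection time
actions = old_dist.sample()
log_prob_old = old_dist.log_prob(actions)  # stored alongside the batch
loc = torch.zeros(4, requires_grad=True)  # current (trainable) policy parameters
new_dist = Normal(loc, torch.ones(4))
loss = ppo_clip_loss(new_dist.log_prob(actions), log_prob_old, advantage=torch.ones(4))
loss.backward()  # gradients flow back through log_prob into loc
```

Any distribution whose log_prob is differentiable with respect to the policy parameters can be dropped in place of Normal here.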
To get those log-probabilities we rely on the dist.log_prob method, which is part of the PyTorch distribution API. This is a mild requirement: PPO relies on policies that are parametric distributions, so more often than not your policy will already have a log_prob method. get_dist, on the other hand, belongs to the TorchRL/TensorDict module API: it interfaces a module with a distribution that has a log_prob. What this function does is rather simple and should be easy to reproduce within a custom module.
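As a rough sketch of both options raised in the question, assuming you go the "subclass torch.distributions.Distribution" route: the Logistic class below is a hand-written example distribution (picked only because torch.distributions has no class by that name), and CustomActor with its get_dist method is a hypothetical module that mimics what get_dist does (parameters in, distribution out). Neither is TorchRL code.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import Distribution, constraints
from torch.distributions.utils import broadcast_all


class Logistic(Distribution):
    """Example custom distribution: a logistic distribution written by hand."""

    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.real
    has_rsample = True

    def __init__(self, loc, scale, validate_args=None):
        self.loc, self.scale = broadcast_all(loc, scale)
        super().__init__(batch_shape=self.loc.shape, validate_args=validate_args)

    @property
    def mean(self):
        return self.loc

    def rsample(self, sample_shape=torch.Size()):
        # Inverse-CDF sampling: x = loc + scale * logit(u), with u ~ Uniform(0, 1)
        shape = self._extended_shape(sample_shape)
        u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device)
        finfo = torch.finfo(u.dtype)
        u = u.clamp(finfo.tiny, 1.0 - finfo.eps)  # avoid log(0)
        return self.loc + self.scale * (torch.log(u) - torch.log1p(-u))

    def log_prob(self, value):
        # log f(x) = -z - 2 * log(1 + exp(-z)) - log(scale), z = (x - loc) / scale
        z = (value - self.loc) / self.scale
        return -z - 2.0 * F.softplus(-z) - self.scale.log()

    def entropy(self):
        # Closed-form entropy of the logistic distribution
        return self.scale.log() + 2.0


class CustomActor(nn.Module):
    """Hypothetical actor that mimics get_dist: observations -> distribution."""

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.net = nn.Linear(obs_dim, 2 * act_dim)

    def get_dist(self, obs):
        loc, raw_scale = self.net(obs).chunk(2, dim=-1)
        # softplus (+ a small floor) keeps the scale strictly positive
        return Logistic(loc, F.softplus(raw_scale) + 1e-4)

    def forward(self, obs):
        dist = self.get_dist(obs)
        action = dist.sample()
        # Return the log-prob as well, since PPO needs it at collection time
        return action, dist.log_prob(action)


torch.manual_seed(0)
actor = CustomActor(obs_dim=3, act_dim=2)
obs = torch.randn(5, 3)
action, log_prob = actor(obs)
print(action.shape, log_prob.shape)  # torch.Size([5, 2]) torch.Size([5, 2])
```

Here log_prob is per action dimension; for a multi-dimensional action you would typically sum over the last dimension, or wrap the distribution in torch.distributions.Independent(dist, 1). A subclass like this should also be usable as the distribution_class of a ProbabilisticActor (with "loc"/"scale" as its distribution inputs), but check the TorchRL docs for your version.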