Log_prob calculation process

I use the output of an MLP network as logits, build a Categorical distribution from them, and then call log_prob. What is the calculation process, and which values are involved when log_prob is computed?

This question is not related to the distributed package, so please change the category.

But I think this is the log_prob logic you are looking for: pytorch/categorical.py at master · pytorch/pytorch · GitHub
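In short, when you pass `logits=...`, Categorical normalizes them as `logits - logits.logsumexp(dim=-1, keepdim=True)` (the same as `log_softmax`), and `log_prob(value)` picks out the normalized logit at each class index, i.e. log p(a) = z_a - log(sum_j exp(z_j)). The only values involved are the raw logits for that row and the chosen index. Below is a minimal sketch that checks this against the built-in method. The MLP (`mlp`, 4 inputs, 3 classes), the batch size and the random inputs are all hypothetical, chosen only for illustration:

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

torch.manual_seed(0)

# Hypothetical small MLP: 4 input features -> 3 class logits
mlp = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 3))

x = torch.randn(2, 4)        # hypothetical batch of 2 inputs
logits = mlp(x)              # shape (2, 3), unnormalized scores
dist = Categorical(logits=logits)

actions = dist.sample()      # shape (2,), integer class indices
lp_builtin = dist.log_prob(actions)

# Manual computation: normalize the logits with log-sum-exp (equivalent to log_softmax) ...
log_pmf = logits - logits.logsumexp(dim=-1, keepdim=True)
# ... then select the normalized log-probability of each sampled class
lp_manual = log_pmf.gather(-1, actions.unsqueeze(-1)).squeeze(-1)

print(lp_builtin)
print(lp_manual)
print(torch.allclose(lp_builtin, lp_manual))
```

Because the normalization goes through `logsumexp` rather than `exp` followed by `log`, it stays numerically stable for large logits, and gradients flow back through both the chosen logit and the normalizer into the MLP.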
