I’m working in Reinforcement Learning and was pretty excited about the potential benefits of torch.compile (jax.jit can speed up RL quite dramatically). I’ve noticed, however, that sampling from a torch probability distribution breaks the FX graph. That’s a bummer for RL, where the policy usually parameterizes a probability distribution.

There are two things going on here: First, you can’t backpropagate through .sample(). But even if you could, the result of Categorical.sample() is an integer, and you can’t backpropagate through integers (they are discrete and therefore not usefully differentiable).
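A quick check (using hypothetical example probabilities) confirms this: the sampled action is an integer tensor and carries no gradient, even when the probabilities themselves require grad.

```python
import torch
from torch.distributions import Categorical

# Example probabilities standing in for a policy network's output
probs = torch.tensor([0.2, 0.3, 0.5], requires_grad=True)
m = Categorical(probs)
action = m.sample()

print(action.dtype)          # torch.int64 -- a discrete index
print(action.requires_grad)  # False -- no gradient flows through .sample()
```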

Depending on your specific use case, you may consider backpropagating through .log_prob(action) (where action is the result of .sample()), which does make sense (even for Categorical) and is permitted.

import torch
from torch.distributions import Categorical

probs = policy_network(state)  # note: Categorical replaces the old multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
# REINFORCE-style loss: weight the log-probability of the action by the reward
loss = -m.log_prob(action) * reward
loss.backward()

Such a scheme may (or may not) make sense for your use case.

Thanks for the suggestion! You’re right that one has to backpropagate through the log-probabilities of the sampled action; this is the policy gradient, and your example is the REINFORCE algorithm.

However, please note that I do not intend to backpropagate through Categorical.sample(). Instead, I’m using a more elaborate version of the algorithm in your example: SAC-discrete, an off-policy actor-critic algorithm whose actor (the policy) is updated similarly to your example.

Thank you very much for answering so quickly. I really appreciate it.

Here’s the link to a repro with a relevant example:

It contains the following:

A simplified example using fake data with only torch dependencies.

A realistic example with the full algorithm I’m using (SAC-discrete).

dynamo.explain() printouts in the README.

I’ve tried wrapping the sampling in torch.no_grad(), but that doesn’t avoid the graph break. One option might be to restructure the code, i.e. separate action generation (which uses .sample()) from the computation of the log-probabilities.

I want to mention that this general structure (sampling from a distribution, then backpropagating through log_prob) is used widely in deep RL across policy gradient and actor-critic algorithms (e.g. REINFORCE, A2C, PPO, SAC-discrete).