Distribution.sample() graph breaks

I’m working in Reinforcement Learning and was pretty excited about the potential benefits of torch.compile (jax.jit can speed up RL quite dramatically). I’ve noticed, however, that sampling from a torch probability distribution breaks the FX graph. That’s a bummer for RL, where the policy usually parameterizes a probability distribution.

Here’s an example network:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

# layer_init is a small weight-initialization helper defined elsewhere in the script
class Actor(nn.Module):
    def __init__(self, envs):
        super().__init__()
        obs_shape = envs.single_observation_space.shape
        self.conv = nn.Sequential(
            layer_init(nn.Conv2d(obs_shape[0], 32, kernel_size=8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
            nn.Flatten(),
        )

        with torch.inference_mode():
            output_dim = self.conv(torch.zeros(1, *obs_shape)).shape[1]

        self.fc1 = layer_init(nn.Linear(output_dim, 512))
        self.fc_logits = layer_init(nn.Linear(512, envs.single_action_space.n))

    def forward(self, x):
        x = F.relu(self.conv(x))
        x = F.relu(self.fc1(x))
        logits = self.fc_logits(x)

        return logits

    def get_action(self, x):
        logits = self(x / 255.0)
        policy_dist = Categorical(logits=logits)
        action = policy_dist.sample() # <- GRAPH BREAK
        action_probs = policy_dist.probs
        log_prob = F.log_softmax(logits, dim=1)
        return action, log_prob, action_probs

Trying to compile the network, or any loss that calls it, always breaks the graph at the indicated line.
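
For reference, here’s roughly how I’m observing it (a simplified sketch using the Actor above; the dynamo.explain call may look slightly different depending on your PyTorch version):

import torch
import torch._dynamo as dynamo

actor = Actor(envs)
obs = torch.zeros(8, *envs.single_observation_space.shape)

# fullgraph=True turns the graph break into a hard error at policy_dist.sample()
compiled_get_action = torch.compile(actor.get_action, fullgraph=True)
# compiled_get_action(obs)  # raises because of the graph break at .sample()

# without fullgraph, dynamo.explain lists the break reasons instead
print(dynamo.explain(actor.get_action)(obs))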

As I don’t understand torch.compile very well I was wondering:

  • Will this be fixed in a future update?
  • If not: is there a workaround I can use?

Thanks in advance for the help and sorry if it’s a bit of a noob question.

Hi @X3N4, could you please share a full repro?

Hi X!

There are two things going on here: First, you can’t backpropagate through .sample(). But even if you could, the result of Categorical.sample() is an integer, and you can’t backpropagate through integers (because they are discrete and therefore not usefully differentiable).

Depending on your specific use case, you may consider backpropagating through .log_prob(action) (where action is the result of .sample()), which does make sense (even for Categorical) and is permitted.

Quoting from the documentation for torch.distributions:

probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()

Such a scheme may (or may not) make sense for your use case.

Good luck!

K. Frank

Hi KFrank,

thanks for the suggestion! Indeed, you are correct that one has to backpropagate through the log-probabilities of the action; that is the policy-gradient approach, and your example is the REINFORCE algorithm.

However, please note that I do not intend to backpropagate through Categorical.sample(). I’m using a more elaborate relative of the algorithm in your example: SAC-discrete, an off-policy actor-critic algorithm whose actor (= policy) is updated in a similar way, roughly as sketched below.
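
For concreteness, the actor update I have in mind looks roughly like this (a sketch only; actor, qf1, qf2 and alpha are placeholders for my policy, the two critics, and the entropy coefficient):

import torch
import torch.nn.functional as F

def actor_loss(obs):
    x = obs / 255.0                        # same scaling as in get_action
    logits = actor(x)
    log_probs = F.log_softmax(logits, dim=1)
    probs = log_probs.exp()
    with torch.no_grad():
        # per-action Q-value estimates from the two critics
        q_values = torch.min(qf1(x), qf2(x))
    # expectation over the action distribution; the gradient never flows
    # through a sampled action
    return (probs * (alpha * log_probs - q_values)).sum(dim=1).mean()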

Hi @marksaroufim

thank you very much for answering so quickly. I really appreciate it.

Here’s the link to a repro with a relevant example:

It contains the following:

  • A simplified example using fake data with only torch dependencies.
  • A realistic example with the full algorithm I’m using (SAC-discrete).
  • dynamo.explain() printouts in the README.

I’ve tried wrapping the sampling in no_grad, but that doesn’t help. An option might be to restructure the code, i.e. separate action generation (which uses .sample()) from the computation of the log-probabilities, roughly as sketched below.
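
Something along these lines is what I have in mind (a rough sketch, not the actual repro code; the Actor from the first post is assumed):

import torch
import torch.nn.functional as F
from torch.distributions import Categorical

@torch.compile
def policy_outputs(actor, x):
    # differentiable part: everything except the sampling
    logits = actor(x / 255.0)
    log_probs = F.log_softmax(logits, dim=1)
    return log_probs, log_probs.exp()

def get_action(actor, x):
    log_probs, probs = policy_outputs(actor, x)
    with torch.no_grad():
        # sampling stays in eager mode, outside the compiled graph
        action = Categorical(probs=probs).sample()
    return action, log_probs, probs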

I want to mention that the general structure of this example (sampling from a distribution and then using log_prob to backpropagate) is used widely in deep RL in many policy gradient/actor-critic algorithms (e.g. REINFORCE, A2C, PPO, SAC-discrete).

Just want to mention that this fixed the graph breaks discussed above.
