Is there a way to include control flow in a traced model?

When tracing a model in which some of the logic depends on whether the module is in training mode, only one of the branches ends up in the traced graph. Here's a minimal example:

```python
import torch
import torch.fx
from torch import nn

tracer = torch.fx.Tracer()

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.path = False

    def forward(self, x):
        if self.training:
            return x + 1
        else:
            return x

module = MyModule()  # modules start in training mode, so tracing records x + 1
graph = tracer.trace(module)
graph_module = torch.fx.GraphModule(module, graph)

graph_module.eval()  # This has no effect on the outcome: the graph still computes x + 1
```

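As you can see, switching graph_module between .train() and .eval() does nothing, because tracing evaluated the condition once and recorded only the branch that was taken; the other branch never made it into the graph.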
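To make the symptom concrete, here is a small sketch that prints the generated code and calls the traced module in both modes. It uses `torch.fx.symbolic_trace`, which is shorthand for the `Tracer` + `GraphModule` pair used above; the exact formatting of the printed code may vary between PyTorch versions.

```python
import torch
import torch.fx
from torch import nn


class MyModule(nn.Module):
    def forward(self, x):
        # self.training is a plain Python bool, so the tracer evaluates
        # this condition once and records only the branch that is taken.
        if self.training:
            return x + 1
        return x


module = MyModule()  # nn.Module starts in training mode
graph_module = torch.fx.symbolic_trace(module)

# The generated forward contains the addition but no `if` statement.
print(graph_module.code)

x = torch.zeros(3)
graph_module.eval()
out_eval = graph_module(x)    # still x + 1: the mode switch is ignored
graph_module.train()
out_train = graph_module(x)   # x + 1, same as above
```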

A more realistic example is some sort of dropout that we only want to switch on during training.
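For the dropout case specifically, it seems to matter how the dropout is written. The sketch below compares two hypothetical wrappers: one calls `F.dropout(..., training=self.training)`, where the flag is read at trace time and stored as a constant; the other uses an `nn.Dropout` submodule, which the default `Tracer` treats as a leaf (it does not trace into modules from `torch.nn`), so the submodule's own `training` flag is consulted on every call.

```python
import torch
import torch.fx
import torch.nn.functional as F
from torch import nn


class FunctionalDropout(nn.Module):
    def forward(self, x):
        # self.training is read once during tracing, so the graph stores
        # F.dropout(..., training=True) with the flag baked in.
        return F.dropout(x, p=0.5, training=self.training)


class ModuleDropout(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.Dropout lives in torch.nn, so the default Tracer records it as a
        # single call_module node instead of tracing into it.
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        return self.dropout(x)


torch.manual_seed(0)
x = torch.ones(1000)

# Both are traced in training mode, then switched to eval mode.
baked = torch.fx.symbolic_trace(FunctionalDropout()).eval()
leaf = torch.fx.symbolic_trace(ModuleDropout()).eval()

baked_out = baked(x)  # still drops elements despite eval()
leaf_out = leaf(x)    # eval() reached the nn.Dropout submodule: identity
```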

Is there any way to build both branches of the control flow into the resulting graph module?
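One possible workaround, sketched here under the assumption that it is acceptable for the branching code to be opaque to FX, is to move the branch into its own submodule and tell the tracer not to trace into it by overriding `Tracer.is_leaf_module`. The names `TrainOnlyAddOne` and `BranchAwareTracer` are hypothetical, chosen for illustration.

```python
import torch
import torch.fx
from torch import nn


class TrainOnlyAddOne(nn.Module):
    # Hypothetical helper holding both branches of the original forward.
    def forward(self, x):
        if self.training:
            return x + 1
        return x


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.branch = TrainOnlyAddOne()

    def forward(self, x):
        return self.branch(x)


class BranchAwareTracer(torch.fx.Tracer):
    def is_leaf_module(self, m, module_qualified_name):
        # Record TrainOnlyAddOne as a single call_module node, so its Python
        # forward (including the `if`) runs on every call of the GraphModule.
        if isinstance(m, TrainOnlyAddOne):
            return True
        return super().is_leaf_module(m, module_qualified_name)


module = MyModule()
graph = BranchAwareTracer().trace(module)
graph_module = torch.fx.GraphModule(module, graph)

x = torch.zeros(3)
train_out = graph_module.train()(x)  # takes the x + 1 branch
eval_out = graph_module.eval()(x)    # takes the identity branch
```

The trade-off is that FX sees the leaf as one opaque call, so graph passes cannot inspect or rewrite what happens inside it. If the control flow itself needs to appear in the IR, `torch.jit.script` preserves Python `if` statements on `self.training`, though that is TorchScript rather than FX.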