When tracing a model whose logic branches on
`self.training`, only one of the branches ends up in the traced graph. Here's a minimal example:
```python
import torch
import torch.fx
from torch import nn

tracer = torch.fx.Tracer()

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.path = False

    def forward(self, x):
        if self.training:
            return x + 1
        else:
            return x

module = MyModule()
module.eval()
graph = tracer.trace(module)
print(graph)
graph_module = torch.fx.GraphModule(module, graph)
graph_module.train()  # This has no effect on the outcome
print(graph_module(torch.zeros(1)))
```
As you can see,
`graph_module.train()` does nothing: the tracing operation recorded only the eval-mode branch and discarded the other one entirely, so the training/eval switch is baked in at trace time.
A more realistic example is when we have some sort of dropout we only want to switch on during training.
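For concreteness, here is a minimal sketch of that situation (the `DropoutModule` name and structure are made up for illustration): the dropout is guarded by a Python `if` on `self.training`, so tracing in eval mode produces a graph that is just the identity, and calling `.train()` on the result can never bring the dropout back.

```python
import torch
import torch.fx
from torch import nn

class DropoutModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        # Python-level branch: evaluated once, at trace time,
        # using the module's current training flag.
        if self.training:
            return self.dropout(x)
        return x

module = DropoutModule()
module.eval()
gm = torch.fx.symbolic_trace(module)
gm.train()  # too late: the traced graph only contains the identity path
print(gm(torch.ones(4)))  # dropout never fires
```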
Any way to somehow build both branches of the control flow into the resulting graph module?