When should I detect a fake tensor mode?

Hello.

I’m writing a constant-folding pass for the FX graph produced by torch.export.export.

When I update the metadata of newly created nodes, I use a fake mode like below.

new_node.meta["val"] = fake_mode.from_tensor(
            prop_constant_tensor, static_shapes=True
        )
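
For context, here is a rough, simplified sketch of what my pass does for each foldable node (the helper name, the get_attr-based replacement, and the attribute naming are just illustrative, and I’m skipping bookkeeping such as updating the export’s signature):

import torch
import torch.fx
from torch.export import ExportedProgram
from torch._subclasses.fake_tensor import FakeTensorMode

def fold_node_to_get_attr(
    ep: ExportedProgram,
    node: torch.fx.Node,
    prop_constant_tensor: torch.Tensor,
    fake_mode: FakeTensorMode,  # <- obtained via one of the two options below
) -> None:
    # Replace `node` with a get_attr node holding the precomputed constant.
    gm = ep.graph_module
    attr_name = f"_folded_{node.name}"
    setattr(gm, attr_name, prop_constant_tensor)
    with gm.graph.inserting_before(node):
        new_node = gm.graph.create_node("get_attr", attr_name)
    # Keep meta["val"] consistent with the rest of the graph by fakeifying
    # the folded tensor under the given fake mode.
    new_node.meta["val"] = fake_mode.from_tensor(
        prop_constant_tensor, static_shapes=True
    )
    node.replace_all_uses_with(new_node)
    gm.graph.erase_node(node)
    gm.recompile()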

But I couldn’t find any documentation on where I should get the fake tensor mode from. I think there are two options.

  1. Detect the existing mode
from torch._guards import detect_fake_mode
from torch.export import ExportedProgram

def get_fake_mode(exported_program: ExportedProgram):
    # Recover the fake mode that export used, from the placeholders' meta["val"].
    fake_mode = detect_fake_mode(
        tuple(
            node.meta["val"]
            for node in exported_program.graph.nodes
            if node.op == "placeholder"
        )
    )
    assert fake_mode is not None
    return fake_mode
  2. Create a new one
from torch._subclasses.fake_tensor import FakeTensorMode

def get_fake_mode():
    return FakeTensorMode()

Which one is better? Is it okay to create a new FakeTensorMode every time a fake mode is needed, or should I detect the existing one?

detect_fake_mode is better, because if the export in question has dynamic shapes you want to reuse the old FakeTensorMode, which has a ShapeEnv tracking the dynamic shapes. If everything is static shapes and your fake tensor usage doesn’t escape, it doesn’t really matter.
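
As a rough illustration (the toy module, the Dim name, and the printed shapes here are just for the example), the mode recovered by detect_fake_mode is the one whose ShapeEnv already tracks the export’s symbolic dimensions, while a freshly created FakeTensorMode shares none of that state:

import torch
from torch.export import Dim, export
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode

class M(torch.nn.Module):
    def forward(self, x):
        return x * 2

batch = Dim("batch")
ep = export(M(), (torch.randn(3, 4),), dynamic_shapes={"x": {0: batch}})

placeholders = [n for n in ep.graph.nodes if n.op == "placeholder"]
x_val = placeholders[0].meta["val"]
print(x_val.shape)  # e.g. torch.Size([s0, 4]) -- dim 0 is a symbol in the export's ShapeEnv

# The detected mode is the one export used, so its ShapeEnv is the one
# tracking that symbol.
export_fake_mode = detect_fake_mode(tuple(n.meta["val"] for n in placeholders))
print(export_fake_mode is x_val.fake_mode)  # True

# A brand-new mode does not share that ShapeEnv (by default it has none),
# so tensors fakeified under it cannot be expressed in terms of the export's symbols.
fresh_fake_mode = FakeTensorMode()
print(fresh_fake_mode.shape_env is export_fake_mode.shape_env)  # False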
