Setting parent and children of nodes with custom params

Hi folks,

I’m looking to take a graph that looks like the one below (left side) and turn it into the graph on the right side, so that it fits on a device with small on-device memory.

I was wondering if there is a recommended way to do such transformations. My output graph is not a linked list of nodes (as in a → b → c), so I don’t think the append and prepend operations will help my use case. Basically, I want to be able to manually set the parent and children of a node, set some parameters for that node/layer, then add a new cat operation and generate a new graph.

Adding a new cat operation seems simple: pytorch/torch/fx/experimental/merge_matmul.py at main · pytorch/pytorch · GitHub
Returning a new graph is also simple: pytorch/torch/fx/experimental/merge_matmul.py at main · pytorch/pytorch · GitHub
Setting parameters: should I just write to the modules dictionary manually?
Setting parents and children: I’m unsure how to set the parent and children of a node, or how to give a node multiple children.

Are any of the public examples suitable as a starting point? @Chillee @James_Reed

There is no direct concept of a “parent” or “child” in FX. The closest analogues are “args” (i.e. the nodes your current node uses) and “users” (i.e. the nodes that use your current node). Furthermore, “users” is derived from “args”, so the only thing you should be modifying is a node’s args.
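To make the args/users relationship concrete, here is a minimal sketch on a toy function (`f` and the `*_node` variable names are made up for illustration). A node with two users is the FX equivalent of a node with two “children”, and rewiring is done purely by assigning new `args`; FX keeps `users` in sync on its own.

```python
import torch
import torch.fx as fx


def f(x):
    a = torch.relu(x)
    b = a + 1  # first user of `a`
    c = a * 2  # second user of `a`
    return b + c


gm = fx.symbolic_trace(f)
# Graph order: placeholder, relu, add, mul, add (b + c), output
x_node, relu_node, add_node, mul_node, add_1_node, output_node = gm.graph.nodes

# `users` is derived from `args`: relu_node appears in the args of two nodes,
# so it has two users (the FX analogue of two "children").
print([n.name for n in relu_node.users])

# Rewire by assigning new args; FX updates the `users` bookkeeping for us.
mul_node.args = (x_node, 2)
print([n.name for n in relu_node.users], [n.name for n in x_node.users])

gm.graph.lint()
gm.recompile()  # regenerate forward() from the edited graph
print(gm.code)
```

After the rewrite, `mul` reads from the input directly, so `relu` is left with a single user and the input gains a second one.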

That being said, it’s not entirely clear to me what graph rewrite you want to do (where does conv3 go?), but I’d imagine it looks a bit like this (assuming you want to replace conv3 with a concat):

  1. Create conv2_1 and conv2_2 from conv2 (probably using nn.Module.add_module).
  2. Create 2 call_module nodes that take conv1 as an arg using fx.Graph.call_module (inserted in front of conv3, e.g. inside fx.Graph.inserting_before).
  3. Create a concat node with those two call_module nodes (conv2_1 and conv2_2) as inputs.
  4. Replace all uses of conv3 with the concat node using fx.Node.replace_all_uses_with.
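Since the original figure isn’t available, here is a minimal sketch of those four steps on a hypothetical `ToyNet` (conv1 → conv2 → relu). Assumptions: the rewrite target is conv2 rather than conv3, the split is over output channels, and `split_conv` is a made-up helper name. It also assumes a reasonably recent PyTorch (1.9+) for `get_submodule` and `delete_submodule`.

```python
import torch
import torch.fx as fx
import torch.nn as nn


class ToyNet(nn.Module):
    # Hypothetical stand-in for the model in the question: conv1 -> conv2 -> relu.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3, padding=1)
        self.conv2 = nn.Conv2d(8, 16, 3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv2(self.conv1(x)))


def split_conv(gm: fx.GraphModule, target: str) -> fx.GraphModule:
    """Hypothetical helper: replace the Conv2d at `target` with two half-width
    convs whose outputs are concatenated along the channel dimension."""
    conv = gm.get_submodule(target)
    node = next(n for n in gm.graph.nodes
                if n.op == "call_module" and n.target == target)
    half = conv.out_channels // 2

    # Step 1: build <target>_1 and <target>_2 from slices of the original
    # weights and register them on the GraphModule with add_module.
    names = []
    for i, (lo, hi) in enumerate([(0, half), (half, conv.out_channels)], start=1):
        part = nn.Conv2d(conv.in_channels, hi - lo, conv.kernel_size,
                         stride=conv.stride, padding=conv.padding,
                         dilation=conv.dilation, bias=conv.bias is not None)
        with torch.no_grad():
            part.weight.copy_(conv.weight[lo:hi])
            if conv.bias is not None:
                part.bias.copy_(conv.bias[lo:hi])
        names.append(f"{target}_{i}")
        gm.add_module(names[-1], part)

    # Steps 2 and 3: insert two call_module nodes that reuse the old node's
    # args (i.e. conv1's output), plus a torch.cat node, in front of the old node.
    with gm.graph.inserting_before(node):
        parts = [gm.graph.call_module(name, args=node.args) for name in names]
        cat = gm.graph.call_function(torch.cat, args=(parts,), kwargs={"dim": 1})

    # Step 4: every user of the old node now reads from the concat instead.
    node.replace_all_uses_with(cat)
    gm.graph.erase_node(node)
    gm.delete_submodule(target)
    gm.graph.lint()
    gm.recompile()
    return gm


model = ToyNet().eval()
x = torch.randn(1, 3, 8, 8)
gm = split_conv(fx.symbolic_trace(model), target="conv2")
print(gm.code)
print(torch.allclose(gm(x), model(x), atol=1e-5))
```

Inserting with `inserting_before` keeps the new nodes in creation order; `inserting_after` on the same node would place each new node directly after it, reversing their order, which is why the sketch inserts in front of the node being replaced.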

It might also be easier for you to use the Subgraph Rewriter API: torch.fx — PyTorch 1.8.1 documentation
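For comparison, a minimal sketch of the subgraph rewriter route using `torch.fx.replace_pattern`. Assumptions: `net`, `pattern`, and `replacement` are hypothetical plain functions, and splitting a matmul into two half-width matmuls joined by `torch.cat` stands in for the conv rewrite from the question, to keep the example self-contained.

```python
import torch
import torch.fx as fx


def net(x, w):
    return torch.relu(torch.matmul(x, w))


# Pattern: the subgraph to look for.
def pattern(x, w):
    return torch.matmul(x, w)


# Replacement: same result computed as two half-width matmuls joined by torch.cat.
def replacement(x, w):
    halves = torch.chunk(w, 2, dim=1)
    return torch.cat([torch.matmul(x, halves[0]), torch.matmul(x, halves[1])], dim=1)


gm = fx.symbolic_trace(net)
matches = fx.replace_pattern(gm, pattern, replacement)  # edits gm in place
print(len(matches))
print(gm.code)
```

The rewriter handles the args/users rewiring and recompilation for you, so you only describe the before and after shapes of the subgraph.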
