I’m trying to figure out how to always get the full FX graph of a module, including all the nodes in its submodules, instead of getting ‘call_module’ nodes that call into the children… or, alternatively, how to then compile the children themselves?
When I try a toy example, I do get what I want: a flat graph of function calls, with no ‘call_module’ nodes. However, when I load something more complex, say a Bert-Tiny, all I get is a graph of ‘call_module’ nodes calling the sub-modules.
Toy example:
import torch

class GrandChildModule(torch.nn.Module):
    def forward(self, x):
        return 5 * x

class ChildModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.child = GrandChildModule()

    def forward(self, x):
        return self.child(x)

class TopModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.child1 = ChildModule()
        self.child2 = ChildModule()

    def forward(self, x):
        return self.child1(1 + x) + self.child2(2 + x)

def my_compiler(gm: torch.fx.GraphModule, example_inputs):
    # print the captured graph and flag any submodule calls
    gm.graph.print_tabular()
    for node in gm.graph.nodes:
        if node.op == "call_module":
            print("Calling: ", node.target)
    return gm.forward  # return a python callable

comp = torch.compile(TopModule(), backend=my_compiler)
comp(torch.randn(3))  # trigger compilation so my_compiler runs
There’s a bunch of hierarchy in there, yet I get a nicely flattened FX graph:
opcode         name    target                   args          kwargs
-------------  ------  -----------------------  ------------  --------
placeholder    x       x                        ()            {}
call_function  add     <built-in function add>  (1, x)        {}
call_function  mul     <built-in function mul>  (5, add)      {}
call_function  add_1   <built-in function add>  (2, x)        {}
call_function  mul_1   <built-in function mul>  (5, add_1)    {}
call_function  add_2   <built-in function add>  (mul, mul_1)  {}
output         output  output                   ((add_2,),)   {}
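For what it’s worth, plain torch.fx.symbolic_trace flattens this toy hierarchy too; as far as I understand, the default tracer only keeps built-in torch.nn modules as leaf ‘call_module’ nodes and inlines custom modules like mine:

from torch.fx import symbolic_trace

# the default Tracer inlines user-defined submodules; only modules
# from the torch.nn namespace are kept as 'call_module' leaves
traced = symbolic_trace(TopModule())
traced.graph.print_tabular()  # prints an equivalently flat graph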
If I try the same on a single-layer bert-tiny:
from transformers import BertModel, BertConfig

cfg = BertConfig.from_pretrained("prajjwal1/bert-tiny")
cfg.num_hidden_layers = 1  # keep a single encoder layer
model = BertModel(config=cfg)
comp = torch.compile(model, backend=my_compiler)

# input_ids plus an attention mask of ones; note that
# torch.randint(0, 1, ...) would give an all-zeros mask,
# since the upper bound is exclusive
comp(torch.randint(0, 20000, (1, 128)), torch.ones(1, 128, dtype=torch.long))
…I get a few built-in function calls, but the majority of the nodes are ‘call_module’ ops:
Calling: self_embeddings_word_embeddings
Calling: self_embeddings_token_type_embeddings
Calling: self_embeddings_position_embeddings
Calling: self_embeddings_LayerNorm
Calling: self_embeddings_dropout
Calling: self_encoder_layer_0_attention_self_query
Calling: self_encoder_layer_0_attention_self_key
Calling: self_encoder_layer_0_attention_self_value
Calling: self_encoder_layer_0_attention_self_dropout
Calling: self_encoder_layer_0_attention_output_dense
Calling: self_encoder_layer_0_attention_output_dropout
Calling: self_encoder_layer_0_attention_output_LayerNorm
Calling: self_encoder_layer_0_intermediate_dense
Calling: self_encoder_layer_0_output_dense
Calling: self_encoder_layer_0_output_dropout
Calling: self_encoder_layer_0_output_LayerNorm
Calling: self_pooler_dense
Calling: self_pooler_activation
Is it that TorchDynamo can’t generate graphs for these modules yet, or is there something I need to turn on / force to make it go deeper?
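For context, one thing I’ve been experimenting with (based on my reading of the custom-backends docs) is wrapping the compiler with aot_autograd, which seems to lower everything to ATen ops so the backend only sees ‘call_function’ nodes; is something like this sketch the intended way to go deeper?

from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func

def my_aten_compiler(gm: torch.fx.GraphModule, example_inputs):
    # gm here should already be lowered to ATen ops,
    # i.e. call_function nodes with no call_module
    gm.graph.print_tabular()
    return make_boxed_func(gm.forward)  # aot_autograd expects a boxed callable

my_backend = aot_autograd(fw_compiler=my_aten_compiler)
comp = torch.compile(model, backend=my_backend)
comp(torch.randint(0, 20000, (1, 128)), torch.ones(1, 128, dtype=torch.long))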