Getting the fx graph of submodules, instead of 'call_module' nodes?

I’m trying to figure out how to always get the full fx graph of a module, including all the nodes in its submodules, instead of getting ‘call_module’ nodes for the children… or, alternatively, how to then compile those children?

When I try a toy example, I do get what I want: a flat graph of function calls, with no ‘call_module’ nodes. However, when I load something more complex, like, say, Bert-Tiny, all I get is a graph of ‘call_module’ nodes calling the sub-modules.
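
For reference, plain torch.fx (without Dynamo) makes this choice explicit: Tracer.is_leaf_module decides which submodules are kept behind ‘call_module’ leaves. A minimal sketch of forcing full inlining there, just to frame what I’m after (I don’t expect vanilla fx tracing to get through BERT’s Python control flow):

  import torch
  import torch.fx

  class InlineAllTracer(torch.fx.Tracer):
      def is_leaf_module(self, m, module_qualified_name):
          # never treat a submodule as a leaf, so even torch.nn built-ins
          # such as nn.Linear end up inlined as call_function nodes
          return False

  # hypothetical usage on some nn.Module instance `model`
  graph = InlineAllTracer().trace(model)
  torch.fx.GraphModule(model, graph).graph.print_tabular()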

Toy example:

  import torch

  class GrandChildModule(torch.nn.Module):
      def forward(self, x):
          return 5 * x

  class ChildModule(torch.nn.Module):
      def __init__(self):
          super().__init__()
          self.child = GrandChildModule()

      def forward(self, x):
          return self.child(x)

  class TopModule(torch.nn.Module):
      def __init__(self):
          super().__init__()
          self.child1 = ChildModule()
          self.child2 = ChildModule()

      def forward(self, x):
          return self.child1(1 + x) + self.child2(2 + x)

  def my_compiler(gm: torch.fx.GraphModule, example_inputs):
      gm.graph.print_tabular()

      for node in gm.graph.nodes:
          if node.op == "call_module":
              print("Calling: ", node.target)
      return gm.forward  # return a python callable
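
For completeness, I compile and invoke the toy model the same way as the BERT example further down:

  comp = torch.compile(TopModule(), backend=my_compiler)
  comp(torch.randn(4))  # any input tensor works; the tracing is symbolic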

There’s a fair amount of hierarchy in there, but I get a nicely flattened fx graph:

opcode         name    target                   args          kwargs
-------------  ------  -----------------------  ------------  --------
placeholder    x       x                        ()            {}
call_function  add     <built-in function add>  (1, x)        {}
call_function  mul     <built-in function mul>  (5, add)      {}
call_function  add_1   <built-in function add>  (2, x)        {}
call_function  mul_1   <built-in function mul>  (5, add_1)    {}
call_function  add_2   <built-in function add>  (mul, mul_1)  {}
output         output  output                   ((add_2,),)   {}

If I try the same on a single-layer Bert-Tiny:

  from transformers import BertModel, BertConfig

  cfg = BertConfig.from_pretrained("prajjwal1/bert-tiny")
  cfg.num_hidden_layers = 1
  model = BertModel(config=cfg)
  comp = torch.compile(model, backend=my_compiler)

  # positional args: input_ids, attention_mask
  comp(torch.randint(0, 20000, (1, 128)), torch.randint(0, 1, (1, 128)))

…I get a few built-in function calls, but the majority of the nodes are ‘call_module’ ops:

Calling:  self_embeddings_word_embeddings
Calling:  self_embeddings_token_type_embeddings
Calling:  self_embeddings_position_embeddings
Calling:  self_embeddings_LayerNorm
Calling:  self_embeddings_dropout
Calling:  self_encoder_layer_0_attention_self_query
Calling:  self_encoder_layer_0_attention_self_key
Calling:  self_encoder_layer_0_attention_self_value
Calling:  self_encoder_layer_0_attention_self_dropout
Calling:  self_encoder_layer_0_attention_output_dense
Calling:  self_encoder_layer_0_attention_output_dropout
Calling:  self_encoder_layer_0_attention_output_LayerNorm
Calling:  self_encoder_layer_0_intermediate_dense
Calling:  self_encoder_layer_0_output_dense
Calling:  self_encoder_layer_0_output_dropout
Calling:  self_encoder_layer_0_output_LayerNorm
Calling:  self_pooler_dense
Calling:  self_pooler_activation

Is it that TorchDynamo can’t generate graphs for these modules yet, or is there something I need to turn on / force to get it to go deeper?
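
One thing I noticed: the toy modules above are all user-defined, whereas everything left as a ‘call_module’ in the BERT graph is a stock torch.nn module (Linear, LayerNorm, Dropout, Embedding), so maybe Dynamo deliberately keeps built-ins as leaves? The workaround I’ve been considering is to re-trace the GraphModule my backend receives with an fx tracer that treats nothing as a leaf, so those built-ins get inlined into call_function nodes. A rough sketch, not yet validated on the full BERT graph:

  def flattening_compiler(gm: torch.fx.GraphModule, example_inputs):
      class InlineAll(torch.fx.Tracer):
          def is_leaf_module(self, m, module_qualified_name):
              # trace through every submodule, including torch.nn built-ins
              return False

      # re-trace the graph Dynamo handed us, so 'call_module' nodes targeting
      # nn.Linear / nn.LayerNorm / etc. become call_function nodes
      flat = torch.fx.GraphModule(gm, InlineAll().trace(gm))
      flat.graph.print_tabular()
      return flat.forward

If there’s a supported switch for this instead, I’d much rather use that.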