How to get Operator from torch::jit::Module

Hi,

How can I get the Operator from a torch::jit::Module?
I want to traverse the whole module and estimate the cost of every op (I want custom cost results, so torch.autograd.profiler is not really applicable).

// some_cpp_code.cpp
  <...>
  m.def("get_cost", [](const torch::jit::Module& mod) {
    auto forward_method = mod.get_method("forward");
    assert(forward_method.graph());
    Costs cost;

    // Traverse the graph; is there a more intuitive way?
    for (Node* n : forward_method.graph()->nodes()) {
      // Calculate the cost for each node kind
      switch (n->kind()) {
        case aten::add:
          cost += my_add_cost(n->inputs(), n->outputs());
          break;  // without break, control falls through to the next case
        case aten::conv2d:
          cost += my_conv2d_cost(n->inputs(), n->outputs());
          break;
        //...
      }
    }
    return cost;
  });

The approach above could (possibly) work, but it makes the code messy, and I don’t actually need that much information. I just need the Operator and its inputs and outputs. The ideal function call would be something like:

Costs my_cost(operator, inputs, outputs) {
  // Using a map to register correct cost function
}

// cost function
Costs my_conv2d_cost(operator, inputs, outputs) {
  // Get other information about this op
  auto padding = inputs[4]; // Is this right?
}
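The map-based dispatch described above could be sketched roughly as follows. This is only an illustration in plain C++ with no libtorch dependency: `Shape`, `CostFn`, `numel`, and `cost_registry` are hypothetical names, `Costs` is assumed to be a plain integer, and the "one unit per output element" cost for aten::add is a toy model. In real code the inputs and outputs would come from the node rather than from shape vectors.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-ins for the types in the post.
using Costs = std::int64_t;
using Shape = std::vector<std::int64_t>;
using CostFn = std::function<Costs(const std::vector<Shape>& inputs,
                                   const std::vector<Shape>& outputs)>;

// Number of elements in a tensor of the given shape.
inline std::int64_t numel(const Shape& s) {
  std::int64_t n = 1;
  for (std::int64_t d : s) n *= d;
  return n;
}

// Registry keyed by operator name, e.g. "aten::add".
inline std::unordered_map<std::string, CostFn>& cost_registry() {
  static std::unordered_map<std::string, CostFn> registry = {
      // Toy cost model: one unit of work per output element.
      {"aten::add",
       [](const std::vector<Shape>&, const std::vector<Shape>& outputs) {
         assert(!outputs.empty());
         return numel(outputs[0]);
       }},
  };
  return registry;
}

// Dispatch on the operator name; unregistered ops contribute zero cost.
inline Costs my_cost(const std::string& op_name,
                     const std::vector<Shape>& inputs,
                     const std::vector<Shape>& outputs) {
  const auto& registry = cost_registry();
  auto it = registry.find(op_name);
  return it == registry.end() ? 0 : it->second(inputs, outputs);
}
```

With this layout, supporting a new op means adding one entry to the registry instead of another case in the switch, and nodes without a registered cost function are simply skipped.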

I have 3 questions:

  1. Does PyTorch offer any way to traverse the whole module and get every op? The code above needs a switch statement with a case for every op, which is quite annoying. It will also encounter many nodes such as prim::Constant, which are not interesting for cost estimation.
  2. Is the getOperation(const Node* node = nullptr) function the one I need? It returns the operation, but I don’t quite understand how to use it. It doesn’t filter out the prim::Constant nodes mentioned above, right?
  3. How can I get the correct inputs? Is the approach shown in my_conv2d_cost the right way?

  1. No, but you can use Node::maybeOperator to skip nodes that don’t have operators (like prim nodes).
  2. No; getOperation returns an Operation, which is an alias for std::function<void(Stack*)>. This is a callable function that executes the operation, but will not be useful for cost analysis. Consider using getOperator or maybeOperator; these functions return an Operator that has an associated schema (Operator::schema()) that can be used to determine which operator it is (essentially the same as n->kind() from your example).
  3. Take a look at the functions in jit/ir.h; you will find functions there that help you access node inputs and their types. However, in the case of tensor types, you might not always find complete shape and dtype information (it depends on how the graph is produced and when your analysis code runs).
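On the third question: the aten::conv2d schema orders its arguments as (input, weight, bias, stride, padding, dilation, groups), so index 4 should indeed be padding. Once the shapes are known, a cost function in the spirit of my_conv2d_cost could be sketched as below. This is a minimal sketch in plain C++ that counts multiply-accumulates only; `Conv2dArgs`, `conv_out_size`, and `conv2d_macs` are hypothetical names, and the values are assumed to have already been extracted from the node.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical container for the conv2d arguments a cost function needs.
struct Conv2dArgs {
  std::array<std::int64_t, 4> input;     // N, C_in, H, W
  std::array<std::int64_t, 4> weight;    // C_out, C_in / groups, kH, kW
  std::array<std::int64_t, 2> stride;    // schema index 3
  std::array<std::int64_t, 2> padding;   // schema index 4
  std::array<std::int64_t, 2> dilation;  // schema index 5
};

// Output extent along one spatial dimension (standard conv2d formula).
inline std::int64_t conv_out_size(std::int64_t in, std::int64_t kernel,
                                  std::int64_t stride, std::int64_t padding,
                                  std::int64_t dilation) {
  assert(stride > 0);
  return (in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1;
}

// Multiply-accumulate count: every output element needs
// (C_in / groups) * kH * kW MACs. Bias addition is ignored.
inline std::int64_t conv2d_macs(const Conv2dArgs& a) {
  const std::int64_t out_h = conv_out_size(a.input[2], a.weight[2],
                                           a.stride[0], a.padding[0],
                                           a.dilation[0]);
  const std::int64_t out_w = conv_out_size(a.input[3], a.weight[3],
                                           a.stride[1], a.padding[1],
                                           a.dilation[1]);
  const std::int64_t out_elems = a.input[0] * a.weight[0] * out_h * out_w;
  return out_elems * a.weight[1] * a.weight[2] * a.weight[3];
}
```

For example, a 1x3x32x32 input with a 16x3x3x3 weight, stride 1, padding 1, and dilation 1 keeps the 32x32 spatial size, so the count is 16 * 32 * 32 * 27 multiply-accumulates. The weight shape already encodes C_in / groups, which is why groups does not appear as a separate field.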

@SplitInfinity

Thanks for the reply! Could you give me some hints about this line?

However, in the case of tensor types, you might not always find complete shape and dtype information (it depends on how the graph is produced and when your analysis code runs).

I have also found that it is hard to get this information from an incomplete tensor type (conv2d’s weight tensor, for example). Could you give me some hints on how PyTorch handles this case? In my opinion, rerunning the full module does not seem like a good idea for large modules.

Unfortunately, there is no way to propagate shape information other than running the model. We are working on ways to let users add this information to graphs directly but this won’t be released anytime soon.

Thanks for your work!

We are working on ways to let users add this information to graphs directly

Do you mean the structured kernel definitions? If so then maybe I could try to refer to existing PRs?

No, that is different. There is ongoing work on a tensor DSL, but unfortunately there’s no public RFC or anything else I can share.