Hi,
How can I get the Operator from a torch::jit::Module?
I want to traverse the whole module and estimate the cost of every op. Since I need custom cost results, torch.autograd.profiler is not really applicable here.
// some_cpp_code.cpp
<...>
m.def("get_cost", [](const torch::jit::Module& mod) {
  auto forward_method = mod.get_method("forward");
  auto graph = forward_method.graph();
  assert(graph);
  Costs cost;
  // Traverse the graph; is there a more intuitive way?
  for (Node* n : graph->nodes()) {
    auto kind = n->kind();
    // Calculate the cost
    switch (kind) {
      case aten::add:
        cost += my_add_cost(n->inputs(), n->outputs());
        break;
      case aten::conv2d:
        cost += my_conv2d_cost(n->inputs(), n->outputs());
        break;
      // ...
    }
  }
  return cost;
});
The above way could (possibly) work, but it makes the code messy, and I don't actually need that much information. I just need the Operator, its inputs, and its outputs. The ideal function call would be something like:
Costs my_cost(operator, inputs, outputs) {
  // Use a map to look up the registered cost function for this op
}

// Per-op cost function
Costs my_conv2d_cost(operator, inputs, outputs) {
  // Get other information about this op
  auto padding = inputs[4]; // Is this right?
}
I have 3 questions:
- Does PyTorch offer any way to traverse the whole module and get every op? The code above has to dispatch on every op kind, which is quite annoying. It also encounters many nodes like prim::Constant, and those nodes are not interesting here.
- Is getOperation(const Node* node = nullptr) the function I need? It returns the Operation, but I don't quite understand how to use it, and it doesn't filter out the prim::Constant nodes mentioned above, right?
- How can I get the correct inputs, as in the my_conv2d_cost code above? Is inputs[4] the right way to read the padding?