Hi, I’m trying to get an optimized graph from the graph executor via ScriptModule’s graph_for(...)
method. A simple test case is below:
import torch
conv = torch.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
# To avoid dealing with prim::Bailout stuff
torch._C._jit_set_profiling_executor(False)
inp = torch.rand(1, 3, 224, 224)
trace = torch.jit.trace(conv, inp).eval()
print(trace.graph_for(inp))
But it seems the graph executor erases shape information when it constructs the optimized graph, so the output of graph_for
has the shape information removed:
graph(%self : __torch__.torch.nn.modules.module.Module,
%input : Float(*, *, *, *)):
%23 : int[] = prim::Constant[value=[0, 0]]()
%22 : int[] = prim::Constant[value=[1, 1]]()
%4 : int = prim::Constant[value=1]() # /home/masa/projects/deep/pytorch/torch/nn/modules/conv.py:345:0
%13 : bool = prim::Constant[value=0]() # /home/masa/projects/deep/pytorch/torch/nn/modules/conv.py:345:0
%20 : bool = prim::Constant[value=1]() # /home/masa/projects/deep/pytorch/torch/nn/modules/conv.py:345:0
%2 : Float(*) = prim::GetAttr[name="bias"](%self)
%3 : Float(*, *, *, *) = prim::GetAttr[name="weight"](%self)
%21 : Float(*, *, *, *) = aten::_convolution(%input, %3, %2, %22, %23, %22, %13, %23, %4, %13, %13, %20) # /home/masa/projects/deep/pytorch/torch/nn/modules/conv.py:345:0
return (%21)
If I know that my input shape is always fixed, is there a way to add explicit shape information? I know that the output of trace preserves shape information, but I want the optimized graph available via graph_for(...).
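In case it helps frame the question: since the traced graph itself keeps the shapes, my current workaround sketch is to run a few optimization passes manually on a copy of trace.graph instead of going through graph_for. This relies on undocumented torch._C._jit_pass_* bindings, so it may not be the intended way and could break across versions:

```python
import torch

# Legacy (non-profiling) executor, to avoid prim::Bailout nodes
torch._C._jit_set_profiling_executor(False)

conv = torch.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3).eval()
inp = torch.rand(1, 3, 224, 224)
trace = torch.jit.trace(conv, inp)

# Work on a copy so trace.graph itself stays untouched
g = trace.graph.copy()
torch._C._jit_pass_inline(g)                # inline submodule calls
torch._C._jit_pass_constant_propagation(g)  # fold constants
torch._C._jit_pass_dce(g)                   # dead code elimination

# Shapes recorded by the tracer survive, because the graph executor
# (which erases them) is never involved
print(g)
```

This only gives me the passes I invoke by hand, though, not the full pipeline that graph_for would run.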