How to add an input shape hint to an optimized graph?

Hi, I’m trying to get an optimized graph from the graph executor via ScriptModule’s graph_for(...) method. A simple test case is below:

import torch

conv = torch.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)

# Use the legacy executor, to avoid dealing with prim::BailOut stuff
torch._C._jit_set_profiling_executor(False)

inp = torch.rand(1, 3, 224, 224)
trace = torch.jit.trace(conv, inp).eval()

print(trace.graph_for(inp))

But it seems the graph executor erases shape information on construction, so the output of graph_for(...) has the shapes removed:

graph(%self : __torch__.torch.nn.modules.module.Module,
      %input : Float(*, *, *, *)):
  %23 : int[] = prim::Constant[value=[0, 0]]()
  %22 : int[] = prim::Constant[value=[1, 1]]()
  %4 : int = prim::Constant[value=1]() # /home/masa/projects/deep/pytorch/torch/nn/modules/conv.py:345:0
  %13 : bool = prim::Constant[value=0]() # /home/masa/projects/deep/pytorch/torch/nn/modules/conv.py:345:0
  %20 : bool = prim::Constant[value=1]() # /home/masa/projects/deep/pytorch/torch/nn/modules/conv.py:345:0
  %2 : Float(*) = prim::GetAttr[name="bias"](%self)
  %3 : Float(*, *, *, *) = prim::GetAttr[name="weight"](%self)
  %21 : Float(*, *, *, *) = aten::_convolution(%input, %3, %2, %22, %23, %22, %13, %23, %4, %13, %13, %20) # /home/masa/projects/deep/pytorch/torch/nn/modules/conv.py:345:0
  return (%21)

If I know that my input shape is always fixed, is there a way to add explicit shape information? I know that the output of trace preserves shape information, but I want the optimized graph available via graph_for(...).
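
To make it concrete, the sketch below is the kind of annotation I have in mind: stamping a concrete input type onto the graph returned by graph_for(...). Value.setType and TensorType.create_from_tensor are internal bindings, so I’m only assuming they are the right tools here, and intermediate values would presumably still need a shape propagation pass:

import torch

torch._C._jit_set_profiling_executor(False)

conv = torch.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
inp = torch.rand(1, 3, 224, 224)
trace = torch.jit.trace(conv, inp).eval()

opt_graph = trace.graph_for(inp)
# The first graph input is %self (the module); the rest line up with the
# example inputs. create_from_tensor is an internal binding, so this is
# an assumption, not a supported API.
for value, example in zip(list(opt_graph.inputs())[1:], (inp,)):
    value.setType(torch._C.TensorType.create_from_tensor(example))

print(opt_graph)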

Could you please elaborate on your use case? The exact static shapes were added to enable backends to generate more efficient kernels. I’m not sure they will be added to serialization or any other public interface.

Without the profiling executor (i.e., with torch._C._jit_set_profiling_executor(False)), we don’t capture the exact shapes, only tensor ranks. Even the rank information is, I believe, considered internal, for use by backends and optimization passes.

OK, my use case is translating TorchScript IR to the TVM compiler’s IR. I have a repo, https://github.com/masahi/torchscript-to-tvm, which demonstrates translating Torch models to TVM, compiling and running them under TVM, and getting output identical to Torch’s.

Right now I take the output of torch.jit.trace(...), which has shape information attached, and translate each Torch operator node to the corresponding one in TVM. Since TVM is mostly a static compiler, shape information is required.
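
The translation itself is essentially a dispatch over node kinds. A stripped-down sketch of its shape, with a placeholder converter instead of real TVM calls (you would invoke it as convert_graph(trace.graph) on the test case above):

def convert_convolution(node):
    # Placeholder: the real converter would emit a TVM convolution here,
    # reading shapes from the node's output type.
    return "tvm convolution"

CONVERT_MAP = {
    "aten::_convolution": convert_convolution,
}

def convert_graph(graph):
    # Walk the graph in order and dispatch each node on its kind.
    for node in graph.nodes():
        kind = node.kind()
        if kind in CONVERT_MAP:
            print(kind, "->", CONVERT_MAP[kind](node))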

TVM has its own set of optimization passes, so taking unoptimized Torch IR as input is no problem. Currently I’m applying the TorchScript inlining pass to remove the prim::CallMethod wrapping (see the snippet below), but that seems rather ad hoc to me, and I would rather apply Torch’s other optimization passes as well.
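
Concretely, this is all I’m doing today; torch._C._jit_pass_inline is the internal binding I found for the inlining pass:

import torch

conv = torch.nn.Conv2d(3, 8, 3)
trace = torch.jit.trace(conv, torch.rand(1, 3, 224, 224)).eval()

graph = trace.graph
# Inline prim::CallMethod / prim::CallFunction nodes into the top-level graph
torch._C._jit_pass_inline(graph)
print(graph)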

I know I can apply each optimization pass manually (and I do for the inline pass), but the torch._C._jit_pass* prefix suggests these APIs are not “officially” supported, so I’m not sure I want to use them directly. Since I discovered that I can access the optimized graph via the graph_for(...) method, I’m looking to see if that is something I can use.

I disabled the profiling executor because it adds prim::BailOut and prim::BailoutTemplate nodes that I have no idea how to translate to TVM. Since input shapes are static in TVM, I don’t think these nodes are relevant to my use case, so I don’t want to see them in the input IR.

I would suggest that you run it under the profiling executor a few times with inputs that cover the different tensor dimensions you expect to use, and then add a pass to remove the bailout nodes. This should give you a graph with shape information at maximum generality.
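
As a sketch (the warm-up count and the exact profiling flags are assumptions that may differ across versions):

import torch

# Use the profiling executor so concrete shapes are recorded.
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)

conv = torch.nn.Conv2d(3, 8, 3)
trace = torch.jit.trace(conv, torch.rand(1, 3, 224, 224)).eval()

# Warm up with the input shapes you expect, so the profiled graph
# specializes on them.
for _ in range(3):
    trace(torch.rand(1, 3, 224, 224))

# The optimized graph now carries profiled shape information, guarded by
# prim::BailOut nodes that a cleanup pass would then remove.
print(trace.graph_for(torch.rand(1, 3, 224, 224)))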

There are also a few passes landing shortly in the PyTorch JIT that should help with this conversion: freezing and a functionalization pass. I will comment here when they have landed.

Thanks, if removing bailouts is possible, that would definitely work for me. At a quick glance there is no pass in Torch to remove bailout nodes, but I’ll see whether I can do it from “userland”.

Another API I’m interested in is the _propagate_shapes function. Do you know the use case for this function, and whether it is something I should take a look at?
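
For reference, this is roughly how I imagine calling it on a traced function; I’m inferring the (graph, inputs, with_grad) signature from the bindings, so treat it as an assumption:

import torch

def f(x):
    return x * 2 + 1

traced = torch.jit.trace(f, torch.rand(2, 3))
# Propagate shapes from concrete example inputs through the graph. The
# signature here is read off the internal bindings, not a documented API.
g = torch._C._propagate_shapes(traced.graph, (torch.rand(2, 3),), False)
print(g)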

This API is only useful with the legacy executor (the one that isn’t the profiling executor), so I don’t think it applies. Yes, the pass doesn’t exist exactly as you need it. The logic should be pretty much the same as right here, except you always remove the guards.
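
In userland that might look something like the sketch below. The input layout of prim::BailOut assumed here (first input the BailoutTemplate output, second the guarded value) is my assumption and may not hold across versions:

import torch

def remove_bailouts(graph):
    for node in graph.findAllNodes("prim::BailOut"):
        # Assumed layout: inputs()[0] is the BailoutTemplate output and
        # inputs()[1] is the value being guarded.
        guarded = list(node.inputs())[1]
        # Keep the profiled (shape-specialized) type so the shape
        # information the bailout was guarding survives the removal.
        guarded.setType(node.output().type())
        node.output().replaceAllUsesWith(guarded)
        node.destroy()
    for node in graph.findAllNodes("prim::BailoutTemplate"):
        node.destroy()
    # Clean up whatever is now dead.
    torch._C._jit_pass_dce(graph)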
