How to get the output shape of each node from a loaded model that was saved by jit?

Hello, guys!
I searched on Google but could not find a solution.

Suppose a model is traced with torch.jit.trace and saved to disk.
When I later load that model and get its graph via model.forward.graph and torch._C._jit_pass_lower_graph, the output shapes of the nodes in the graph are lost. How can I recover them?
Here is some example code:

import torch
import torchvision
from torch._C import _propagate_and_assign_input_shapes

def _model_to_graph(model, args):
    # Lower the graph so parameters become explicit inputs, then
    # propagate the example input shapes through it.
    if isinstance(args, torch.Tensor):
        args = (args,)
    graph = model.forward.graph
    method_graph, params = torch._C._jit_pass_lower_graph(graph, model._c)
    in_vars, in_desc = torch.jit._flatten(tuple(args) + tuple(params))
    graph = _propagate_and_assign_input_shapes(
        method_graph, tuple(in_vars), False, False)
    return graph
traced_model_savepath = 'traced.pt'
model = torchvision.models.mobilenet.mobilenet_v2(pretrained=True)
dummy_input = torch.rand((1, 3, 224, 224))
traced_model = torch.jit.trace(model, dummy_input)
traced_model.save(traced_model_savepath)
# load the saved traced model from disk
loaded_traced_model = torch.jit.load(traced_model_savepath)
graph = _model_to_graph(loaded_traced_model, dummy_input)
print(graph)

Sadly, the output shapes of the nodes are lost. For example:

  %input.6 : Tensor = aten::_convolution(%626, %251, %276, %627, %628, %629, %277, %630, %274, %277, %277, %278) # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
  %input.104 : Tensor = aten::batch_norm(%input.6, %252, %253, %254, %255, %278, %280, %279, %278) # /data0/shareVR/pytorch/learn/torch/nn/functional.py:2051:0

Node %input.104 is a batch_norm layer and its output type is just `Tensor`, but what is its output shape? How can I get it?

Many thanks!

There is torch._C._jit_pass_complete_shape_analysis(graph, inputs, with_grad: bool), which works but is internal.
See issue 39690 for some details.
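
A minimal invocation might look like this (an untested sketch; method_graph and in_vars are assumed to come from _jit_pass_lower_graph and torch.jit._flatten, as in your code above):

# Untested sketch: the pass mutates the graph in place,
# annotating value types where it can infer them.
torch._C._jit_pass_complete_shape_analysis(method_graph, tuple(in_vars), False)
print(method_graph)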

Best regards

Thomas

Thanks for your reply; I will try your suggestion.

I am sorry, but I am not sure how to use torch._C._jit_pass_complete_shape_analysis correctly.
I found at https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/init.cpp#L359
that torch._C._jit_pass_complete_shape_analysis takes the Graph as a shared_ptr.
So I modified my code as follows:

import torch
import torch.nn as nn

class DemoNet(nn.Module):
    def __init__(self):
        super(DemoNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=(3, 3))
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3))
        self.relu1 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=(3, 3))

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.relu1(x)
        x = self.conv3(x)
        return x

def _model_to_graph(model, args):
    # Lower the graph, then run the internal shape-analysis pass in place.
    if isinstance(args, torch.Tensor):
        args = (args,)
    graph = model.forward.graph
    method_graph, params = torch._C._jit_pass_lower_graph(graph, model._c)
    in_vars, in_desc = torch.jit._flatten(tuple(args) + tuple(params))
    torch._C._jit_pass_complete_shape_analysis(method_graph, tuple(in_vars), False)
    return method_graph

traced_model_savepath = 'traced.pt'
model = DemoNet()
model.eval()
dummy_input = torch.randn((2, 3, 224, 224))

traced_model = torch.jit.trace(model, dummy_input)
traced_model.save(traced_model_savepath)

# load the saved traced model
loaded_traced_model = torch.jit.load(traced_model_savepath)

graph = _model_to_graph(loaded_traced_model, dummy_input)
print(graph)

I use a toy DemoNet to keep the output graph small.
The output is:

graph(%input.2 : Float(2:150528, 3:50176, 224:224, 224:1, requires_grad=0, device=cpu),
      %44 : Float(32:27, 3:9, 3:3, 3:1, requires_grad=1, device=cpu),
      %45 : Float(32:1, requires_grad=1, device=cpu),
      %46 : Float(64:288, 32:9, 3:3, 3:1, requires_grad=1, device=cpu),
      %47 : Float(64:1, requires_grad=1, device=cpu),
      %48 : Float(32:576, 64:9, 3:3, 3:1, requires_grad=1, device=cpu),
      %49 : Float(32:1, requires_grad=1, device=cpu)):
  %10 : int = prim::Constant[value=0]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
  %11 : int = prim::Constant[value=1]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
  %12 : bool = prim::Constant[value=0]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
  %13 : bool = prim::Constant[value=1]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
  %16 : int[] = prim::ListConstruct(%11, %11)
  %17 : int[] = prim::ListConstruct(%10, %10)
  %18 : int[] = prim::ListConstruct(%11, %11)
  %19 : int[] = prim::ListConstruct(%10, %10)
  %input0.1 : Float(*, *, *, *, requires_grad=0, device=cpu) = aten::_convolution(%input.2, %44, %45, %16, %17, %18, %12, %19, %11, %12, %12, %13) # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
  %21 : int = prim::Constant[value=0]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
  %22 : int = prim::Constant[value=1]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
  %23 : bool = prim::Constant[value=0]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
  %24 : bool = prim::Constant[value=1]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
  %27 : int[] = prim::ListConstruct(%22, %22)
  %28 : int[] = prim::ListConstruct(%21, %21)
  %29 : int[] = prim::ListConstruct(%22, %22)
  %30 : int[] = prim::ListConstruct(%21, %21)
  %input.1 : Float(*, *, *, *, requires_grad=0, device=cpu) = aten::_convolution(%input0.1, %46, %47, %27, %28, %29, %23, %30, %22, %23, %23, %24) # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
  %32 : Tensor = aten::relu_(%input.1) # /data0/shareVR/pytorch/learn/torch/nn/functional.py:1125:0
  %33 : int = prim::Constant[value=0]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
  %34 : int = prim::Constant[value=1]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
  %35 : bool = prim::Constant[value=0]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
  %36 : bool = prim::Constant[value=1]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
  %39 : int[] = prim::ListConstruct(%34, %34)
  %40 : int[] = prim::ListConstruct(%33, %33)
  %41 : int[] = prim::ListConstruct(%34, %34)
  %42 : int[] = prim::ListConstruct(%33, %33)
  %43 : Tensor = aten::_convolution(%32, %48, %49, %39, %40, %41, %35, %42, %34, %35, %35, %36) # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
  return (%43)

It seems that torch._C._jit_pass_complete_shape_analysis cannot tell me the exact output shapes of
nodes such as %input0.1, %input.1, %32, and %43.

For %input0.1 and %input.1, it only tells me that they are 4-D tensors, with no exact shape information;
for %32 and %43, it only tells me that they are tensors.
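
In the meantime, here is a small helper I use to dump whatever type information the pass did attach to each node output (a sketch; it assumes graph is the result of _model_to_graph above):

def print_node_output_shapes(graph):
    # Walk every node and print whatever type/shape information
    # shape analysis attached to its outputs.
    for node in graph.nodes():
        for out in node.outputs():
            t = out.type()
            sizes = None
            if isinstance(t, torch._C.TensorType):
                # sizes() raises when the shape is unknown or only
                # partially known, so guard the call.
                try:
                    sizes = t.sizes()
                except RuntimeError:
                    sizes = None
            print(out.debugName(), t, sizes)

print_node_output_shapes(graph)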

Any suggestions on how to use torch._C._jit_pass_complete_shape_analysis correctly?
Thanks a lot.

So apparently it cannot infer the shape of the convolution (it should be able to when the stride, padding, dilation, etc. parameters are constants). :confused:
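
One thing that might be worth trying (an untested sketch): run constant propagation on the lowered graph first, so that the prim::ListConstruct nodes feeding the stride/padding/dilation arguments fold into constant lists before shape analysis runs:

# Untested sketch: fold constant ListConstructs into constants,
# then rerun the shape-analysis pass on the lowered graph.
torch._C._jit_pass_constant_propagation(method_graph)
torch._C._jit_pass_complete_shape_analysis(method_graph, tuple(in_vars), False)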

So is there a conclusion to this question? I would also like to keep static tensor shape information in TorchScript when loading a saved model.
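
For what it's worth, the fallback I ended up with is to record shapes on the eager model with forward hooks while running the dummy input; a sketch (it assumes you still have the original nn.Module, not only the saved .pt file):

import torch
import torch.nn as nn

def record_output_shapes(model, dummy_input):
    # Run dummy_input through the eager model and record each leaf
    # submodule's output shape via forward hooks.
    shapes, handles = {}, []
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            def hook(mod, inputs, output, name=name):
                if isinstance(output, torch.Tensor):
                    shapes[name] = tuple(output.shape)
            handles.append(module.register_forward_hook(hook))
    with torch.no_grad():
        model(dummy_input)
    for h in handles:
        h.remove()
    return shapes

# For example, with the DemoNet above:
# print(record_output_shapes(DemoNet().eval(), torch.randn(2, 3, 224, 224)))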