I'm not sure how to use torch._C._jit_pass_complete_shape_analysis correctly.
From https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/init.cpp#L359
I can see that torch._C._jit_pass_complete_shape_analysis takes the Graph as a shared_ptr,
so I modified my code as follows:
import torch
import torch.nn as nn

class DemoNet(nn.Module):
    def __init__(self):
        super(DemoNet, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=(3, 3))
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=(3, 3))
        self.relu1 = torch.nn.ReLU(inplace=True)
        self.conv3 = torch.nn.Conv2d(64, 32, kernel_size=(3, 3))

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.relu1(x)
        x = self.conv3(x)
        return x

def _model_to_graph(model, args):
    if isinstance(args, torch.Tensor):
        args = (args,)
    graph = model.forward.graph
    # Lower the graph: the module's parameters become extra graph inputs.
    method_graph, params = torch._C._jit_pass_lower_graph(graph, model._c)
    # Flatten the real inputs plus the lifted parameters into one tuple of tensors.
    in_vars, in_desc = torch.jit._flatten(tuple(args) + tuple(params))
    torch._C._jit_pass_complete_shape_analysis(method_graph, tuple(in_vars), False)
    return method_graph

traced_model_savepath = 'traced.pt'
model = DemoNet()
model.eval()
dummy_input = torch.randn((2, 3, 224, 224))
traced_model = torch.jit.trace(model, dummy_input)
traced_model.save(traced_model_savepath)

# load the saved traced model
loaded_traced_model = torch.jit.load(traced_model_savepath)
graph = _model_to_graph(loaded_traced_model, dummy_input)
print(graph)
I use a toy DemoNet to keep the output graph small.
The output is:
graph(%input.2 : Float(2:150528, 3:50176, 224:224, 224:1, requires_grad=0, device=cpu),
%44 : Float(32:27, 3:9, 3:3, 3:1, requires_grad=1, device=cpu),
%45 : Float(32:1, requires_grad=1, device=cpu),
%46 : Float(64:288, 32:9, 3:3, 3:1, requires_grad=1, device=cpu),
%47 : Float(64:1, requires_grad=1, device=cpu),
%48 : Float(32:576, 64:9, 3:3, 3:1, requires_grad=1, device=cpu),
%49 : Float(32:1, requires_grad=1, device=cpu)):
%10 : int = prim::Constant[value=0]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
%11 : int = prim::Constant[value=1]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
%12 : bool = prim::Constant[value=0]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
%13 : bool = prim::Constant[value=1]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
%16 : int[] = prim::ListConstruct(%11, %11)
%17 : int[] = prim::ListConstruct(%10, %10)
%18 : int[] = prim::ListConstruct(%11, %11)
%19 : int[] = prim::ListConstruct(%10, %10)
%input0.1 : Float(*, *, *, *, requires_grad=0, device=cpu) = aten::_convolution(%input.2, %44, %45, %16, %17, %18, %12, %19, %11, %12, %12, %13) # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
%21 : int = prim::Constant[value=0]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
%22 : int = prim::Constant[value=1]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
%23 : bool = prim::Constant[value=0]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
%24 : bool = prim::Constant[value=1]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
%27 : int[] = prim::ListConstruct(%22, %22)
%28 : int[] = prim::ListConstruct(%21, %21)
%29 : int[] = prim::ListConstruct(%22, %22)
%30 : int[] = prim::ListConstruct(%21, %21)
%input.1 : Float(*, *, *, *, requires_grad=0, device=cpu) = aten::_convolution(%input0.1, %46, %47, %27, %28, %29, %23, %30, %22, %23, %23, %24) # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
%32 : Tensor = aten::relu_(%input.1) # /data0/shareVR/pytorch/learn/torch/nn/functional.py:1125:0
%33 : int = prim::Constant[value=0]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
%34 : int = prim::Constant[value=1]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
%35 : bool = prim::Constant[value=0]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
%36 : bool = prim::Constant[value=1]() # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
%39 : int[] = prim::ListConstruct(%34, %34)
%40 : int[] = prim::ListConstruct(%33, %33)
%41 : int[] = prim::ListConstruct(%34, %34)
%42 : int[] = prim::ListConstruct(%33, %33)
%43 : Tensor = aten::_convolution(%32, %48, %49, %39, %40, %41, %35, %42, %34, %35, %35, %36) # /data0/shareVR/pytorch/learn/torch/nn/modules/conv.py:416:0
return (%43)
It seems that torch._C._jit_pass_complete_shape_analysis cannot tell me the exact output shapes of nodes such as %input0.1, %input.1, %32, and %43.
For %input0.1 and %input.1, it only tells me that they are 4-D tensors (Float(*, *, *, *)), with no exact shape information; for %32 and %43, it only tells me that they are tensors.
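To make the comparison concrete, here is a small sketch I use to dump whatever type the pass inferred for each value. Graph.nodes(), Node.outputs(), Value.debugName(), and Value.type() are undocumented torch._C internals, so treat this as an assumption about the current bindings:

def print_inferred_types(graph):
    # Print every node output with whatever type the analysis assigned,
    # e.g. "Float(*, *, *, *, ...)" when only the rank was inferred.
    for node in graph.nodes():
        for out in node.outputs():
            print('%{} : {}'.format(out.debugName(), out.type()))

print_inferred_types(graph)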
Any suggestions on how to use torch._C._jit_pass_complete_shape_analysis correctly?
Thanks a lot.