Hi, I have a module composed of Conv2d and ReLU.
In profiling mode, we can find one prim::DifferentiableGraph
node in the graph of the scripted module. However, no prim::DifferentiableGraph
node is found if we trace the same module.
Could someone explain why there is this difference in behaviour between script and trace?
The code:
from __future__ import division
import argparse

import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(MyModule, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


def test(mode):
    print("*" * 10, mode, "*" * 10)
    ConvRelu = MyModule(3, 32, kernel_size=3, stride=1)
    x = torch.randn((1, 3, 8, 8))
    x.requires_grad = True
    if mode == 'script':
        m = torch.jit.script(ConvRelu)
    else:
        m = torch.jit.trace(ConvRelu, x)
    # Print twice: with the profiling executor, the graph changes once
    # profiling information has been collected.
    print('Conv2d+Relu Graph:\n', m.graph_for(x))
    print('Conv2d+Relu Graph:\n', m.graph_for(x))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', '-m', required=True, choices=['script', 'trace'],
                        help="whether to script or trace the module")
    args = parser.parse_args()
    # Enable the profiling graph executor.
    torch._C._jit_set_profiling_mode(True)
    torch._C._jit_set_profiling_executor(True)
    test(args.mode)
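Aside: as an alternative to graph_for, builds around this commit also expose torch.jit.last_executed_optimized_graph(), which returns the graph the profiling executor actually ran last. A minimal sketch, assuming the function is available in your build:

m(x)  # execute the module once so there is a last-executed graph
print(torch.jit.last_executed_optimized_graph())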
Script:
python script_trace.py -m script
********** script **********
Conv2d+Relu Graph:
graph(%self : __torch__.MyModule,
%x.1 : Tensor):
%2 : int[] = prim::Constant[value=[0, 0]]()
%3 : int[] = prim::Constant[value=[1, 1]]()
%4 : int = prim::Constant[value=1]() # /home/pytorch/torch/nn/modules/conv.py:343:47
%5 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self)
%6 : Tensor = prim::GetAttr[name="weight"](%5)
%7 : Tensor? = prim::GetAttr[name="bias"](%5)
%8 : Tensor = prim::profile(%x.1)
%9 : Tensor = prim::profile(%6)
%10 : Tensor = aten::conv2d(%8, %9, %7, %3, %2, %3, %4) # /home/pytorch/torch/nn/modules/conv.py:345:15
%11 : Tensor = prim::profile(%10)
%result.2 : Tensor = aten::relu(%11) # /home/pytorch/torch/nn/functional.py:1063:17
%13 : Tensor = prim::profile(%result.2)
= prim::profile()
return (%13)
Conv2d+Relu Graph:
graph(%self : __torch__.MyModule,
%x.1 : Tensor):
%4 : int = prim::Constant[value=1]() # /home/pytorch/torch/nn/modules/conv.py:343:47
%3 : int[] = prim::Constant[value=[1, 1]]()
%2 : int[] = prim::Constant[value=[0, 0]]()
%21 : int = prim::BailoutTemplate_0()
%18 : Float(1, 3, 8, 8) = prim::BailOut[index=0](%21, %x.1, %self)
%5 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self)
%6 : Tensor = prim::GetAttr[name="weight"](%5)
%19 : Float(32, 3, 3, 3) = prim::BailOut[index=1](%21, %6, %5, %18)
%7 : Tensor? = prim::GetAttr[name="bias"](%5)
%10 : Tensor = aten::conv2d(%18, %19, %7, %3, %2, %3, %4) # /home/pytorch/torch/nn/modules/conv.py:345:15
%20 : Float(1, 32, 6, 6) = prim::BailOut[index=2](%21, %10)
%result.2 : Float(1, 32, 6, 6) = prim::DifferentiableGraph_1(%20)
return (%result.2)
with prim::BailoutTemplate_0 = graph(%self : __torch__.MyModule,
%x.1 : Tensor):
%2 : Float(1, 3, 8, 8) = prim::BailOut[index=0](%x.1, %self)
%3 : int[] = prim::Constant[value=[0, 0]]()
%4 : int[] = prim::Constant[value=[1, 1]]()
%5 : int = prim::Constant[value=1]() # /home/pytorch/torch/nn/modules/conv.py:343:47
%6 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self)
%7 : Tensor = prim::GetAttr[name="weight"](%6)
%8 : Float(32, 3, 3, 3) = prim::BailOut[index=1](%7, %6, %2)
%9 : Tensor? = prim::GetAttr[name="bias"](%6)
%10 : Tensor = aten::conv2d(%2, %8, %9, %4, %3, %4, %5) # /home/pytorch/torch/nn/modules/conv.py:345:15
%11 : Float(1, 32, 6, 6) = prim::BailOut[index=2](%10)
%result.2 : Float(1, 32, 6, 6) = aten::relu(%11) # /home/pytorch/torch/nn/functional.py:1063:17
return (%result.2)
with prim::DifferentiableGraph_1 = graph(%0 : Float(1, 32, 6, 6)):
%result.3 : Float(1, 32, 6, 6) = aten::relu(%0) # /home/pytorch/torch/nn/functional.py:1063:17
return (%result.3)
We can find one prim::DifferentiableGraph node when we print the graph the second time: the first print still shows the graph instrumented with prim::profile nodes, and once profiling information has been collected, the optimized graph contains the bailouts and a prim::DifferentiableGraph holding the aten::relu.
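This can also be checked programmatically instead of by eyeballing the dump. A minimal sketch reusing m and x from the script above (has_differentiable_graph is just an illustrative helper name, not a PyTorch API):

def has_differentiable_graph(graph):
    # Scan the top-level nodes of the torch._C.Graph for differentiable subgraphs.
    return any(n.kind() == 'prim::DifferentiableGraph' for n in graph.nodes())

print(has_differentiable_graph(m.graph_for(x)))  # True here, False in the trace case below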
Trace:
python script_trace.py -m trace
********** trace **********
Conv2d+Relu Graph:
graph(%self.1 : __torch__.MyModule,
%input.1 : Tensor):
%7 : None = prim::Constant(), scope: __module.conv
%6 : int[] = prim::Constant[value=[1, 1]]()
%5 : int[] = prim::Constant[value=[0, 0]]()
%4 : bool = prim::Constant[value=0](), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
%3 : int = prim::Constant[value=1](), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
%2 : bool = prim::Constant[value=1](), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
%23 : int = prim::BailoutTemplate_0()
%20 : Float(1, 3, 8, 8) = prim::BailOut[index=0](%23, %input.1, %self.1)
%8 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self.1)
%9 : Tensor = prim::GetAttr[name="weight"](%8)
%21 : Float(32, 3, 3, 3) = prim::BailOut[index=1](%23, %9, %20)
%input : Tensor = aten::_convolution(%20, %21, %7, %6, %5, %6, %4, %5, %3, %4, %4, %2), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
%22 : Float(1, 32, 6, 6) = prim::BailOut[index=2](%23, %input)
%14 : Float(1, 32, 6, 6) = aten::relu(%22), scope: __module.relu # /home/pytorch/torch/nn/functional.py:1063:0
return (%14)
with prim::BailoutTemplate_0 = graph(%self.1 : __torch__.MyModule,
%input.1 : Tensor):
%2 : Float(1, 3, 8, 8) = prim::BailOut[index=0](%input.1, %self.1)
%3 : bool = prim::Constant[value=1](), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
%4 : int = prim::Constant[value=1](), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
%5 : bool = prim::Constant[value=0](), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
%6 : int[] = prim::Constant[value=[0, 0]]()
%7 : int[] = prim::Constant[value=[1, 1]]()
%8 : None = prim::Constant(), scope: __module.conv
%9 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self.1)
%10 : Tensor = prim::GetAttr[name="weight"](%9)
%11 : Float(32, 3, 3, 3) = prim::BailOut[index=1](%10, %2)
%input : Tensor = aten::_convolution(%2, %11, %8, %7, %6, %7, %5, %6, %4, %5, %5, %3), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
%13 : Float(1, 32, 6, 6) = prim::BailOut[index=2](%input)
%14 : Float(1, 32, 6, 6) = aten::relu(%13), scope: __module.relu # /home/pytorch/torch/nn/functional.py:1063:0
return (%14)
Conv2d+Relu Graph:
graph(%self.1 : __torch__.MyModule,
%input.1 : Tensor):
%7 : None = prim::Constant(), scope: __module.conv
%6 : int[] = prim::Constant[value=[1, 1]]()
%5 : int[] = prim::Constant[value=[0, 0]]()
%4 : bool = prim::Constant[value=0](), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
%3 : int = prim::Constant[value=1](), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
%2 : bool = prim::Constant[value=1](), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
%23 : int = prim::BailoutTemplate_0()
%20 : Float(1, 3, 8, 8) = prim::BailOut[index=0](%23, %input.1, %self.1)
%8 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self.1)
%9 : Tensor = prim::GetAttr[name="weight"](%8)
%21 : Float(32, 3, 3, 3) = prim::BailOut[index=1](%23, %9, %20)
%input : Tensor = aten::_convolution(%20, %21, %7, %6, %5, %6, %4, %5, %3, %4, %4, %2), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
%22 : Float(1, 32, 6, 6) = prim::BailOut[index=2](%23, %input)
%14 : Float(1, 32, 6, 6) = aten::relu(%22), scope: __module.relu # /home/pytorch/torch/nn/functional.py:1063:0
return (%14)
with prim::BailoutTemplate_0 = graph(%self.1 : __torch__.MyModule,
%input.1 : Tensor):
%2 : Float(1, 3, 8, 8) = prim::BailOut[index=0](%input.1, %self.1)
%3 : bool = prim::Constant[value=1](), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
%4 : int = prim::Constant[value=1](), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
%5 : bool = prim::Constant[value=0](), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
%6 : int[] = prim::Constant[value=[0, 0]]()
%7 : int[] = prim::Constant[value=[1, 1]]()
%8 : None = prim::Constant(), scope: __module.conv
%9 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self.1)
%10 : Tensor = prim::GetAttr[name="weight"](%9)
%11 : Float(32, 3, 3, 3) = prim::BailOut[index=1](%10, %2)
%input : Tensor = aten::_convolution(%2, %11, %8, %7, %6, %7, %5, %6, %4, %5, %5, %3), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
%13 : Float(1, 32, 6, 6) = prim::BailOut[index=2](%input)
%14 : Float(1, 32, 6, 6) = aten::relu(%13), scope: __module.relu # /home/pytorch/torch/nn/functional.py:1063:0
return (%14)
In this case, there is no prim::DifferentiableGraph node in either printed graph, even though the module and input are the same. Note that the traced graph records the lower-level aten::_convolution instead of aten::conv2d.
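The difference in the recorded ops is easy to see by dumping the node kinds of each optimized graph. A minimal sketch (same m and x as above; the expected results in the comments are read off the dumps in this post):

# Collect the kinds of the top-level nodes in the optimized graph.
kinds = sorted({n.kind() for n in m.graph_for(x).nodes()})
print(kinds)
# script: includes 'aten::conv2d' and 'prim::DifferentiableGraph'
# trace:  includes 'aten::_convolution' and a bare 'aten::relu'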
PyTorch commit:
commit b58f89b2e4b4a6dc9fbc0c00e608de0f4db52267
Changes made to the thresholds (lowered so that even a single-node differentiable subgraph, here the lone aten::relu, is created and kept rather than inlined back):
diff --git a/torch/csrc/jit/runtime/graph_executor_impl.h b/torch/csrc/jit/runtime/graph_executor_impl.h
index a2fd10c..f4f43e2 100644
--- a/torch/csrc/jit/runtime/graph_executor_impl.h
+++ b/torch/csrc/jit/runtime/graph_executor_impl.h
@@ -40,8 +40,8 @@ bool getAutodiffSubgraphInlining();
// Tunable parameters for deciding when to create/keep subgraphs of
// differentiable code
-const size_t autodiffSubgraphNodeThreshold = 2;
-const size_t autodiffSubgraphInlineThreshold = 5;
+const size_t autodiffSubgraphNodeThreshold = 1;
+const size_t autodiffSubgraphInlineThreshold = 1;
Thanks.