In profiling mode, prim::DifferentiableGraph is present in the graph of a scripted module but not in the graph of the same module when traced

Hi, I have a module composed of Conv2d and ReLU.

In profiling mode, we can find one prim::DifferentiableGraph node in the optimized graph of the scripted module. However, no prim::DifferentiableGraph node appears when we trace the same module.

Could someone explain why there is this difference in behaviour between script and trace?

The code:

from __future__ import division

import argparse

import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(MyModule, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


def test(mode):
    print("*" * 10, mode, "*" * 10)
    ConvRelu = MyModule(3, 32, kernel_size=3, stride=1)
    x = torch.randn((1, 3, 8, 8))
    x.requires_grad = True
    if mode == 'script':
        m = torch.jit.script(ConvRelu)
    else:
        m = torch.jit.trace(ConvRelu, x)
    # First call: the profiling executor instruments the graph
    # with prim::profile nodes and records tensor shapes.
    print('Conv2d+Relu Graph:\n', m.graph_for(x))
    # Second call: with profiling data available, this prints the
    # optimized graph (bailouts and, possibly, prim::DifferentiableGraph).
    print('Conv2d+Relu Graph:\n', m.graph_for(x))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', '-m', required=True, choices=['script', 'trace'],
                        help="to script or to trace the module")
    args = parser.parse_args()
    torch._C._jit_set_profiling_mode(True)
    torch._C._jit_set_profiling_executor(True)
    test(args.mode)

Script:

python script_trace.py -m script

********** script **********
Conv2d+Relu Graph:
 graph(%self : __torch__.MyModule,
      %x.1 : Tensor):
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %4 : int = prim::Constant[value=1]() # /home/pytorch/torch/nn/modules/conv.py:343:47
  %5 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self)
  %6 : Tensor = prim::GetAttr[name="weight"](%5)
  %7 : Tensor? = prim::GetAttr[name="bias"](%5)
  %8 : Tensor = prim::profile(%x.1)
  %9 : Tensor = prim::profile(%6)
  %10 : Tensor = aten::conv2d(%8, %9, %7, %3, %2, %3, %4) # /home/pytorch/torch/nn/modules/conv.py:345:15
  %11 : Tensor = prim::profile(%10)
  %result.2 : Tensor = aten::relu(%11) # /home/pytorch/torch/nn/functional.py:1063:17
  %13 : Tensor = prim::profile(%result.2)
   = prim::profile()
  return (%13)

Conv2d+Relu Graph:
 graph(%self : __torch__.MyModule,
      %x.1 : Tensor):
  %4 : int = prim::Constant[value=1]() # /home/pytorch/torch/nn/modules/conv.py:343:47
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %21 : int = prim::BailoutTemplate_0()
  %18 : Float(1, 3, 8, 8) = prim::BailOut[index=0](%21, %x.1, %self)
  %5 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self)
  %6 : Tensor = prim::GetAttr[name="weight"](%5)
  %19 : Float(32, 3, 3, 3) = prim::BailOut[index=1](%21, %6, %5, %18)
  %7 : Tensor? = prim::GetAttr[name="bias"](%5)
  %10 : Tensor = aten::conv2d(%18, %19, %7, %3, %2, %3, %4) # /home/pytorch/torch/nn/modules/conv.py:345:15
  %20 : Float(1, 32, 6, 6) = prim::BailOut[index=2](%21, %10)
  %result.2 : Float(1, 32, 6, 6) = prim::DifferentiableGraph_1(%20)
  return (%result.2)
with prim::BailoutTemplate_0 = graph(%self : __torch__.MyModule,
      %x.1 : Tensor):
  %2 : Float(1, 3, 8, 8) = prim::BailOut[index=0](%x.1, %self)
  %3 : int[] = prim::Constant[value=[0, 0]]()
  %4 : int[] = prim::Constant[value=[1, 1]]()
  %5 : int = prim::Constant[value=1]() # /home/pytorch/torch/nn/modules/conv.py:343:47
  %6 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self)
  %7 : Tensor = prim::GetAttr[name="weight"](%6)
  %8 : Float(32, 3, 3, 3) = prim::BailOut[index=1](%7, %6, %2)
  %9 : Tensor? = prim::GetAttr[name="bias"](%6)
  %10 : Tensor = aten::conv2d(%2, %8, %9, %4, %3, %4, %5) # /home/pytorch/torch/nn/modules/conv.py:345:15
  %11 : Float(1, 32, 6, 6) = prim::BailOut[index=2](%10)
  %result.2 : Float(1, 32, 6, 6) = aten::relu(%11) # /home/pytorch/torch/nn/functional.py:1063:17
  return (%result.2)
with prim::DifferentiableGraph_1 = graph(%0 : Float(1, 32, 6, 6)):
  %result.3 : Float(1, 32, 6, 6) = aten::relu(%0) # /home/pytorch/torch/nn/functional.py:1063:17
  return (%result.3)

We can find one prim::DifferentiableGraph node (here wrapping just the aten::relu) the second time we print the graph, i.e. after the profiling run.
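
For anyone reproducing this, a small helper (hypothetical, not part of the repro above) makes the check programmatic instead of eyeballing the dump; it relies only on the Graph.nodes() and Node.kind() APIs:

def has_node_kind(graph, kind):
    # Scan the top-level nodes of a JIT graph for a given node kind;
    # the prim::DifferentiableGraph node above sits at the top level.
    return any(node.kind() == kind for node in graph.nodes())

# After the profiling run (i.e. on the second graph_for call):
print(has_node_kind(m.graph_for(x), 'prim::DifferentiableGraph'))  # True when scripted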

Trace:

python script_trace.py -m trace

********** trace **********
Conv2d+Relu Graph:
 graph(%self.1 : __torch__.MyModule,
      %input.1 : Tensor):
  %7 : None = prim::Constant(), scope: __module.conv
  %6 : int[] = prim::Constant[value=[1, 1]]()
  %5 : int[] = prim::Constant[value=[0, 0]]()
  %4 : bool = prim::Constant[value=0](), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
  %3 : int = prim::Constant[value=1](), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
  %2 : bool = prim::Constant[value=1](), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
  %23 : int = prim::BailoutTemplate_0()
  %20 : Float(1, 3, 8, 8) = prim::BailOut[index=0](%23, %input.1, %self.1)
  %8 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self.1)
  %9 : Tensor = prim::GetAttr[name="weight"](%8)
  %21 : Float(32, 3, 3, 3) = prim::BailOut[index=1](%23, %9, %20)
  %input : Tensor = aten::_convolution(%20, %21, %7, %6, %5, %6, %4, %5, %3, %4, %4, %2), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
  %22 : Float(1, 32, 6, 6) = prim::BailOut[index=2](%23, %input)
  %14 : Float(1, 32, 6, 6) = aten::relu(%22), scope: __module.relu # /home/pytorch/torch/nn/functional.py:1063:0
  return (%14)
with prim::BailoutTemplate_0 = graph(%self.1 : __torch__.MyModule,
      %input.1 : Tensor):
  %2 : Float(1, 3, 8, 8) = prim::BailOut[index=0](%input.1, %self.1)
  %3 : bool = prim::Constant[value=1](), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
  %4 : int = prim::Constant[value=1](), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
  %5 : bool = prim::Constant[value=0](), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
  %6 : int[] = prim::Constant[value=[0, 0]]()
  %7 : int[] = prim::Constant[value=[1, 1]]()
  %8 : None = prim::Constant(), scope: __module.conv
  %9 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self.1)
  %10 : Tensor = prim::GetAttr[name="weight"](%9)
  %11 : Float(32, 3, 3, 3) = prim::BailOut[index=1](%10, %2)
  %input : Tensor = aten::_convolution(%2, %11, %8, %7, %6, %7, %5, %6, %4, %5, %5, %3), scope: __module.conv # /home/pytorch/torch/nn/modules/conv.py:346:0
  %13 : Float(1, 32, 6, 6) = prim::BailOut[index=2](%input)
  %14 : Float(1, 32, 6, 6) = aten::relu(%13), scope: __module.relu # /home/pytorch/torch/nn/functional.py:1063:0
  return (%14)


Conv2d+Relu Graph:
 [second print: identical to the first graph above, including the BailoutTemplate; still no prim::DifferentiableGraph]

In this case, there is no prim::DifferentiableGraph node in the graph, even in the second print after the profiling run.
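
One visible difference between the two dumps is the convolution op itself: the scripted graph keeps aten::conv2d, while the traced graph records the lower-level aten::_convolution. A minimal sketch to surface both observations from Python (assumes the MyModule definition from the repro above):

import torch

torch._C._jit_set_profiling_mode(True)
torch._C._jit_set_profiling_executor(True)

mod = MyModule(3, 32, kernel_size=3, stride=1)
x = torch.randn(1, 3, 8, 8, requires_grad=True)

scripted = torch.jit.script(mod)
traced = torch.jit.trace(mod, x)

for name, m in [('script', scripted), ('trace', traced)]:
    m.graph_for(x)  # first call: profiling run
    kinds = sorted({n.kind() for n in m.graph_for(x).nodes()})
    print(name, kinds)
# Expected from the dumps above: the 'script' set includes
# prim::DifferentiableGraph; the 'trace' set does not, and shows
# aten::_convolution where the scripted graph has aten::conv2d.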

PyTorch commit:

commit b58f89b2e4b4a6dc9fbc0c00e608de0f4db52267

Changes made to the autodiff subgraph thresholds (lowered to 1 so that even a single differentiable node, such as the lone aten::relu here, can form and keep a prim::DifferentiableGraph):

diff --git a/torch/csrc/jit/runtime/graph_executor_impl.h b/torch/csrc/jit/runtime/graph_executor_impl.h
index a2fd10c..f4f43e2 100644
--- a/torch/csrc/jit/runtime/graph_executor_impl.h
+++ b/torch/csrc/jit/runtime/graph_executor_impl.h
@@ -40,8 +40,8 @@ bool getAutodiffSubgraphInlining();
 
 // Tunable parameters for deciding when to create/keep subgraphs of
 // differentiable code
-const size_t autodiffSubgraphNodeThreshold = 2;
-const size_t autodiffSubgraphInlineThreshold = 5;
+const size_t autodiffSubgraphNodeThreshold = 1;
+const size_t autodiffSubgraphInlineThreshold = 1;
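
As an aside, the inlining half of this can also be toggled from Python without rebuilding, via the debug hook behind getAutodiffSubgraphInlining() in this header (assuming the hook is exposed in this build; it does not lower autodiffSubgraphNodeThreshold itself):

# Keep small autodiff subgraphs instead of inlining them back into the main
# graph; this complements, but does not replace, the node-threshold change.
torch._C._debug_set_autodiff_subgraph_inlining(False)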

Thanks.