Graph_executor behavior change in 1.5

Hi all,

It seems that .graph_for no longer runs optimizations in PyTorch 1.5.

For example, for the following simple script, graph_for dumps different graphs in PyTorch 1.4 and PyTorch 1.5.

$ cat small.py
import torch

def f(x,y): return x + y * 3

print(torch.jit.script(f).graph_for(torch.rand(2, 2, device='cuda'), torch.rand(2, 2, device='cuda')))

PyTorch 1.4:

$ python small.py
graph(%x.1 : Float(*, *),
      %y.1 : Float(*, *)):
  %6 : Float(*, *) = prim::FusionGroup_0(%x.1, %y.1)
  return (%6)
with prim::FusionGroup_0 = graph(%0 : Float(*, *),
      %4 : Float(*, *)):
  %2 : int = prim::Constant[value=1]()
  %5 : int = prim::Constant[value=3]() # small.py:3:27
  %6 : Float(*, *) = aten::mul(%4, %5) # small.py:3:23
  %3 : Float(*, *) = aten::add(%0, %6, %2) # small.py:3:19
  return (%3)

PyTorch 1.5:

$ python small.py
graph(%x.1 : Tensor,
      %y.1 : Tensor):
  %3 : int = prim::Constant[value=3]() # small.py:3:27
  %2 : int = prim::Constant[value=1]()
  %4 : Tensor = aten::mul(%y.1, %3) # small.py:3:23
  %5 : Tensor = aten::add(%x.1, %4, %2) # small.py:3:19
  return (%5)

Is this expected? With PyTorch 1.5, I don’t see any fusion group.


I found the description of the change in the release notes: https://github.com/pytorch/pytorch/releases/tag/v1.5.0.

But even with torch._C._jit_set_profiling_mode(True), I don’t see any fusion group in the output.

$ cat small.py
import torch
torch._C._jit_set_profiling_mode(True)

def f(x,y): return x + y * 3

print(torch.jit.script(f).graph_for(torch.rand(2, 2, device='cuda'), torch.rand(2, 2, device='cuda')))

$ python small.py
graph(%x.1 : Tensor,
      %y.1 : Tensor):
  %2 : int = prim::Constant[value=3]() # small.py:4:27
  %3 : int = prim::Constant[value=1]()
  %4 : Tensor = prim::profile(%y.1)
  %5 : Tensor = aten::mul(%4, %2) # small.py:4:23
  %6 : Tensor = prim::profile(%x.1)
  %7 : Tensor = prim::profile(%5)
  %8 : Tensor = aten::add(%6, %7, %3) # small.py:4:19
  %9 : Tensor = prim::profile(%8)
   = prim::profile()
  return (%9)

Could you try running the code with the latest nightly build?
If you can still reproduce the issue, feel free to create an issue here and include a reproducible code snippet.
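
In case it helps with the report, here is a minimal sketch of a reproduction script. It rests on an assumption that is not confirmed in this thread: that the profiling executor may need a few warm-up calls before it produces an optimized graph, which would explain the prim::profile nodes in the dump above. The helper name warmed_up_graph and the warm-up count of 3 are hypothetical choices for illustration, and the script falls back to CPU when CUDA is not available.

```python
import torch


def f(x, y):
    return x + y * 3


def warmed_up_graph(fn, *args, warmup_runs=3):
    """Script `fn`, call it a few times, then return (scripted_fn, graph).

    Hypothetical helper: the warm-up loop assumes the executor only
    specializes the graph after it has seen some real inputs.
    """
    scripted = torch.jit.script(fn)
    for _ in range(warmup_runs):
        scripted(*args)
    return scripted, scripted.graph_for(*args)


# Use CUDA when present (as in the original post), otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.rand(2, 2, device=device)
y = torch.rand(2, 2, device=device)

scripted, graph = warmed_up_graph(f, x, y)

# Include the version in the report so the behavior can be tied to a build.
print(torch.__version__)
print(graph)
```

Whether a fusion group shows up after the warm-up calls depends on the build and the device, so the printed graph together with the version string is the useful part to paste into the issue.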