Hi all,
It seems that .graph_for no longer runs optimizations with PyTorch 1.5.
For example, for the following simple code, graph_for
dumps different graphs under PyTorch 1.4 and PyTorch 1.5.
$ cat small.py
import torch
def f(x,y): return x + y * 3
print(torch.jit.script(f).graph_for(torch.rand(2, 2, device='cuda'), torch.rand(2, 2, device='cuda')))
PyTorch 1.4:
$ python small.py
graph(%x.1 : Float(*, *),
      %y.1 : Float(*, *)):
  %6 : Float(*, *) = prim::FusionGroup_0(%x.1, %y.1)
  return (%6)
with prim::FusionGroup_0 = graph(%0 : Float(*, *),
      %4 : Float(*, *)):
  %2 : int = prim::Constant[value=1]()
  %5 : int = prim::Constant[value=3]() # small.py:3:27
  %6 : Float(*, *) = aten::mul(%4, %5) # small.py:3:23
  %3 : Float(*, *) = aten::add(%0, %6, %2) # small.py:3:19
  return (%3)
PyTorch 1.5:
$ python small.py
graph(%x.1 : Tensor,
      %y.1 : Tensor):
  %3 : int = prim::Constant[value=3]() # small.py:3:27
  %2 : int = prim::Constant[value=1]()
  %4 : Tensor = aten::mul(%y.1, %3) # small.py:3:23
  %5 : Tensor = aten::add(%x.1, %4, %2) # small.py:3:19
  return (%5)
Is this expected? With PyTorch 1.5, I don’t see any fusion group.
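
In case it helps with reproducing, one thing I tried is disabling the profiling graph executor, which (if I understand correctly) became the default in 1.5 and may explain the difference. This is a minimal sketch using the internal `torch._C._jit_set_profiling_executor` switch, which is not a public API and may change between releases:

```python
import torch

# Assumption: PyTorch 1.5 enables the profiling graph executor by default,
# so graph_for no longer reports the legacy-optimized graph up front.
# _jit_set_profiling_executor is an internal API; it returns the previous
# setting and may change between releases.
torch._C._jit_set_profiling_executor(False)  # fall back to the legacy executor

def f(x, y):
    return x + y * 3

scripted = torch.jit.script(f)
x = torch.rand(2, 2)
y = torch.rand(2, 2)
# With the legacy executor, graph_for should again show the
# shape-specialized graph for these inputs.
print(scripted.graph_for(x, y))
```

With this switch flipped, the dump looks closer to the 1.4 output for me, but I am not sure whether this is the intended way to get the optimized graph in 1.5.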