Help with simple jit scripting

Hi all,
I’ve been trying to use jit scripting but I seem to be getting no success. I’ve now tried to create a very simple test case as shown. I’m running this on conda python 3.7, cuda 10.1, pytorch 1.5.1, CentOS 7.6 installed as the website says.

import torch

@torch.jit.script
def simple_kernel(x1, y1, x2, y2):
    xi = torch.max(x1, x2)
    yi = torch.max(y1, y2)
    zi = xi + yi
    return zi

x1, y1, x2, y2 = torch.randn(4, 10_000_000, device='cuda')
zz = simple_kernel(x1, y1, x2, y2)
simple_kernel.graph_for(x1, y1, x2, y2)
print(simple_kernel.graph)

When I run this, I see the following:

(base) [jlquinn test]$ PYTORCH_FUSION_DEBUG=1 python jittst1.py
graph(%x1.1 : Tensor,
      %y1.1 : Tensor,
      %x2.1 : Tensor,
      %y2.1 : Tensor):
  %12 : int = prim::Constant[value=1]()
  %xi.1 : Tensor = aten::max(%x1.1, %x2.1) # jittst1.py:6:9
  %yi.1 : Tensor = aten::max(%y1.1, %y2.1) # jittst1.py:7:9
  %zi.1 : Tensor = aten::add(%xi.1, %yi.1, %12) # jittst1.py:8:9
  return (%zi.1)

From what I’ve read, this should be an easy case for PyTorch to fuse, since it consists only of simple pointwise operations, but it appears that no fusion is happening. Can anyone enlighten me as to what I’m missing?

Thanks
Jerry

Hi,
In PyTorch 1.5 you need to disable the profiling executor in order to see the fusion group via graph_for. Thus:

torch._C._jit_set_profiling_executor(False)

@torch.jit.script
def simple_kernel(x1, y1, x2, y2):
    xi = torch.max(x1, x2)
    yi = torch.max(y1, y2)
    zi = xi + yi
    return zi

zz = simple_kernel(x1, y1, x2, y2)
print(simple_kernel.graph_for(x1, y1, x2, y2))

will print the optimized graph, which contains the fusion group.
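For anyone trying this without a GPU, here is a self-contained sketch of the same fix (an assumption on my part: the private _jit_set_profiling_executor flag is present in your build, as it is in 1.5). With CPU tensors you won’t see a prim::FusionGroup in 1.5, since that fuser is CUDA-only, but the snippet runs anywhere and shows the API; switch to device='cuda' to see the actual fusion.

```python
import torch

# Disable the profiling executor so graph_for shows the optimized
# graph right away (PyTorch 1.5-era behavior; private API).
torch._C._jit_set_profiling_executor(False)

@torch.jit.script
def simple_kernel(x1, y1, x2, y2):
    xi = torch.max(x1, x2)  # element-wise maximum
    yi = torch.max(y1, y2)
    return xi + yi

# CPU tensors so this runs without CUDA; use device='cuda' to see
# the pointwise ops collapsed into a prim::FusionGroup.
x1, y1, x2, y2 = torch.randn(4, 1000)

zz = simple_kernel(x1, y1, x2, y2)
print(simple_kernel.graph_for(x1, y1, x2, y2))
```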

Thanks for taking the time to respond. Is this documented anywhere yet? And more importantly: when I run a torch.jit.script function, it will get optimized after it’s been run a few times, right?
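In other words, if I leave the profiling executor at its default and just warm the function up, I’d expect something like this (a sketch, using CPU tensors so it runs anywhere; on CPU in 1.5 the graph won’t contain a fusion group, but the warm-up pattern is the same):

```python
import torch

@torch.jit.script
def simple_kernel(x1, y1, x2, y2):
    return torch.max(x1, x2) + torch.max(y1, y2)

# CPU tensors so this runs anywhere; fusion itself needs CUDA in 1.5.
x1, y1, x2, y2 = torch.randn(4, 1000)

# Warm-up: with the default profiling executor, the first calls run an
# instrumented graph that records shapes/dtypes for later optimization.
for _ in range(5):
    zz = simple_kernel(x1, y1, x2, y2)

# After warm-up, graph_for should show the specialized/optimized graph.
print(simple_kernel.graph_for(x1, y1, x2, y2))
```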