Hi all,
I’ve been trying to use JIT scripting, but so far without success, so I’ve put together a very simple test case. I’m running this with conda Python 3.7, CUDA 10.1, and PyTorch 1.5.1 on CentOS 7.6, installed as the website describes.
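For reference, the versions can be confirmed with a quick check like this (just the standard attributes, nothing specific to the problem):
import torch

# Quick check of the installed versions (expected: PyTorch 1.5.1, CUDA 10.1).
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.is_available())
Here is the test case: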
import torch

@torch.jit.script
def simple_kernel(x1, y1, x2, y2):
    xi = torch.max(x1, x2)
    yi = torch.max(y1, y2)
    zi = xi + yi
    return zi

x1, y1, x2, y2 = torch.randn(4, 10_000_000, device='cuda')
zz = simple_kernel(x1, y1, x2, y2)
simple_kernel.graph_for(x1, y1, x2, y2)
print(simple_kernel.graph)
When I run this, I see the following:
(base) [jlquinn test]$ PYTORCH_FUSION_DEBUG=1 python jittst1.py
graph(%x1.1 : Tensor,
      %y1.1 : Tensor,
      %x2.1 : Tensor,
      %y2.1 : Tensor):
  %12 : int = prim::Constant[value=1]()
  %xi.1 : Tensor = aten::max(%x1.1, %x2.1) # jittst1.py:6:9
  %yi.1 : Tensor = aten::max(%y1.1, %y2.1) # jittst1.py:7:9
  %zi.1 : Tensor = aten::add(%xi.1, %yi.1, %12) # jittst1.py:8:9
  return (%zi.1)
From what I’ve read, this should be an easy case for PyTorch to fuse, since it consists only of simple pointwise operations, but there doesn’t appear to be any fusion happening. Can anyone enlighten me as to what I’m missing?
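For completeness, the script above calls graph_for but only prints graph; my understanding (possibly wrong) is that any fusion group would only show up in the optimized graph, so the check I had in mind looks roughly like this:
# Sketch: after the warm-up call above, print the graph specialized for these inputs
# (assuming graph_for returns the optimized graph rather than the plain scripted one).
print(simple_kernel.graph_for(x1, y1, x2, y2))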
Thanks
Jerry