I have recently started learning TorchScript and tried to visualize the optimized graph, but without success. The following is my code:
@torch.jit.script
def cell_end(ingate, forgetgate, cellgate, outgate, cx):
    """Combine four pre-activation gate tensors with a previous cell state.

    All arguments are unannotated, so TorchScript treats each one as a
    ``Tensor``. Every operation is elementwise, so the inputs must have
    shapes that broadcast against each other.

    Args:
        ingate: Pre-activation values; passed through ``sigmoid``.
        forgetgate: Pre-activation values; passed through ``sigmoid``.
        cellgate: Pre-activation values; passed through ``tanh``.
        outgate: Pre-activation values; passed through ``sigmoid``.
        cx: Previous cell state, scaled by the activated ``forgetgate``.

    Returns:
        A tuple ``(hy, cy)`` where
        ``cy = sigmoid(forgetgate) * cx + sigmoid(ingate) * tanh(cellgate)``
        and ``hy = sigmoid(outgate) * tanh(cy)``.
    """
    # Squash each gate's pre-activation with its nonlinearity.
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    # New cell state: retained part of the old state plus the gated candidate.
    cy = (forgetgate * cx) + (ingate * cellgate)
    # Output: the squashed new cell state, scaled by the output gate.
    hy = outgate * torch.tanh(cy)
    return hy, cy
For the above function I got the following graph:
graph(%ingate.1 : Tensor,
%forgetgate.1 : Tensor,
%cellgate.1 : Tensor,
%outgate.1 : Tensor,
%cx.1 : Tensor):
%19 : int = prim::Constant[value=1]()
%ingate.3 : Tensor = aten::sigmoid(%ingate.1) # <ipython-input-2-8da29633480c>:3:13
%forgetgate.3 : Tensor = aten::sigmoid(%forgetgate.1) # <ipython-input-2-8da29633480c>:4:17
%cellgate.3 : Tensor = aten::tanh(%cellgate.1) # <ipython-input-2-8da29633480c>:5:15
%outgate.3 : Tensor = aten::sigmoid(%outgate.1) # <ipython-input-2-8da29633480c>:6:14
%15 : Tensor = aten::mul(%forgetgate.3, %cx.1) # <ipython-input-2-8da29633480c>:8:10
%18 : Tensor = aten::mul(%ingate.3, %cellgate.3) # <ipython-input-2-8da29633480c>:8:30
%cy.1 : Tensor = aten::add(%15, %18, %19) # <ipython-input-2-8da29633480c>:8:10
%23 : Tensor = aten::tanh(%cy.1) # <ipython-input-2-8da29633480c>:9:19
%hy.1 : Tensor = aten::mul(%outgate.3, %23) # <ipython-input-2-8da29633480c>:9:9
%27 : (Tensor, Tensor) = prim::TupleConstruct(%hy.1, %cy.1)
return (%27)
I then ran the function with an input:
# One random tensor whose leading dimension (size 5) supplies the five
# arguments of cell_end; each argument is therefore a (10, 4) slice.
inp = torch.randn(5, 10, 4)
# Split explicitly along dim 0 instead of relying on tensor iteration.
cell_end(*torch.unbind(inp, dim=0))
Even after running the function on this input, I got the same graph:
graph(%ingate.1 : Tensor,
%forgetgate.1 : Tensor,
%cellgate.1 : Tensor,
%outgate.1 : Tensor,
%cx.1 : Tensor):
%5 : int = prim::Constant[value=1]()
%ingate.3 : Tensor = aten::sigmoid(%ingate.1) # <ipython-input-2-8da29633480c>:3:13
%forgetgate.3 : Tensor = aten::sigmoid(%forgetgate.1) # <ipython-input-2-8da29633480c>:4:17
%cellgate.3 : Tensor = aten::tanh(%cellgate.1) # <ipython-input-2-8da29633480c>:5:15
%outgate.3 : Tensor = aten::sigmoid(%outgate.1) # <ipython-input-2-8da29633480c>:6:14
%10 : Tensor = aten::mul(%forgetgate.3, %cx.1) # <ipython-input-2-8da29633480c>:8:10
%11 : Tensor = aten::mul(%ingate.3, %cellgate.3) # <ipython-input-2-8da29633480c>:8:30
%cy.1 : Tensor = aten::add(%10, %11, %5) # <ipython-input-2-8da29633480c>:8:10
%13 : Tensor = aten::tanh(%cy.1) # <ipython-input-2-8da29633480c>:9:19
%hy.1 : Tensor = aten::mul(%outgate.3, %13) # <ipython-input-2-8da29633480c>:9:9
%15 : (Tensor, Tensor) = prim::TupleConstruct(%hy.1, %cy.1)
return (%15)
I don't know the reason behind this. Could anyone please help me?