I’m trying to migrate a PyTorch model into our own C++ framework. Currently I’m using torch.jit.trace to get the graph of the model, which is convenient: I can identify most ops through the node name, like ‘aten::add’ or ‘aten::conv2d’. But for the following model:
```python
import torch

class A(torch.nn.Module):
    def __init__(self):
        super(A, self).__init__()
        self.layer1 = torch.nn.Linear(10, 20)
        self.layer2 = torch.nn.Linear(20, 10)

    def forward(self, xx):  # xx = torch.randn(50, 3, 10)
        # iterate over the first dimension, one (3, 10) slice at a time
        for i in range(xx.shape[0]):
            x = xx[i]
            x = self.layer1(x)
            x = self.layer2(x)
        return x
```
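One way to check what the tracer actually recorded for this model is to list the node kinds in the traced graph. This is a sketch that assumes the class `A` above and a smaller input of shape (5, 3, 10), an arbitrary choice to keep the graph short. `inlined_graph` is used so the `Linear` submodule calls show up as ops instead of method calls. Because `torch.jit.trace` only records the operations executed for the example input, the Python `for` loop is unrolled into straight-line ops and no loop node is expected.

```python
import torch

class A(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(10, 20)
        self.layer2 = torch.nn.Linear(20, 10)

    def forward(self, xx):
        for i in range(xx.shape[0]):
            x = xx[i]
            x = self.layer1(x)
            x = self.layer2(x)
        return x

# Smaller example input than (50, 3, 10), chosen only to keep the graph short.
# Tracing may emit a TracerWarning about range(xx.shape[0]).
traced = torch.jit.trace(A(), torch.randn(5, 3, 10))

# inlined_graph inlines the Linear submodule calls into the top-level graph.
kinds = [node.kind() for node in traced.inlined_graph.nodes()]
print(kinds)
print("prim::Loop" in kinds)  # the tracer recorded no loop node
```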
I have trouble finding out whether the model uses the for loop, and which node(s) are inside it. Is there a way to find the loop part?
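If the loop needs to survive into the exported graph, `torch.jit.script` compiles the Python control flow and represents the `for` loop as a `prim::Loop` node whose body is a nested block. Below is a minimal sketch, assuming scripting is acceptable for this model. `collect_loops` is a hypothetical helper written for illustration, not a PyTorch API, and `x` is initialized before the loop because TorchScript needs a variable that is used after a loop to be defined before it.

```python
import torch

class A(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(10, 20)
        self.layer2 = torch.nn.Linear(20, 10)

    def forward(self, xx):
        # Defined before the loop so TorchScript can see it after the loop.
        x = xx[0]
        for i in range(xx.shape[0]):
            x = xx[i]
            x = self.layer1(x)
            x = self.layer2(x)
        return x


def collect_loops(nodes, found):
    """Hypothetical helper: recursively gather every prim::Loop node,
    including loops nested inside other blocks (if/loop bodies)."""
    for node in nodes:
        if node.kind() == "prim::Loop":
            found.append(node)
        for block in node.blocks():
            collect_loops(block.nodes(), found)
    return found


scripted = torch.jit.script(A())
graph = scripted.inlined_graph  # Linear calls inlined into the graph
loops = collect_loops(graph.nodes(), [])

for loop in loops:
    body = list(loop.blocks())[0]  # a prim::Loop has a single body block
    print("loop body:", [n.kind() for n in body.nodes()])
```

In the scripted graph, a `prim::Loop` node takes the maximum trip count and an initial condition as its first two inputs, followed by any loop-carried values, and the first input of its body block is the iteration counter. A C++ importer would likely need to map this node onto its own loop construct rather than onto a single op.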