How to use torch.jit to trace a graph?

Hi all,

I'm trying to trace my own graph using torch.jit, but I run into the following problem:

  1. Create a simple model:
import torch
import torch.nn as nn
from torch import jit
from torch.autograd import Variable

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, kernel_size=1)

    def forward(self, x):
        x = self.conv1(x)
        return x
  2. Create an input variable and trace the model:
net = SimpleNet()
var = Variable(torch.randn(1, 3, 224, 224))
trace, out = jit.trace(net, var)
trace.export(list(net.state_dict().values()))

I get the following output:

Traceback (most recent call last):
  File "run.py", line 17, in <module>
    trace.export(list(net.state_dict().values()))
RuntimeError: ONNX export failed: Couldn't export C++ operator ConvForward

Graph we tried to export:
graph(%1 : Float(1, 3, 224, 224)
      %2 : Float(10, 3, 1, 1)
      %3 : Float(10)) {
  %5 : Float(1, 10, 224, 224), %6 : Handle = CppOp[ConvForward](%1, %2, %3), uses = [[%0.i0], []];
  return (%5);
}

The same error occurs with every network I have tried.
How can I fix it? Thanks in advance!
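For readers on newer releases: in PyTorch 1.x and later, torch.jit.trace(module, example_input) returns a traced ScriptModule instead of a (trace, out) pair, so the unpacking and the trace.export(...) call from the question do not apply there. The snippet below is a minimal sketch, assuming a recent PyTorch release, that traces the same SimpleNet and prints what was recorded; the names example and traced are illustrative.

```python
import torch
import torch.nn as nn


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, kernel_size=1)

    def forward(self, x):
        return self.conv1(x)


net = SimpleNet().eval()
example = torch.randn(1, 3, 224, 224)  # example input used to record the trace

# Runs the model once on `example` and records the executed operations.
traced = torch.jit.trace(net, example)

print(traced.graph)  # TorchScript IR of the recorded operations
print(traced.code)   # Python-like view of the same graph

# The traced module is callable like the original one.
out = traced(example)
print(out.shape)     # torch.Size([1, 10, 224, 224])
```

Tracing records only the operations executed for this particular input, so data-dependent control flow in forward() is not captured; torch.jit.script is the usual alternative for that case.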

Hi,

In PyTorch v1.5.1, you can use trace.code. Here is an example: