[pytorch-1.0rc1] using torch.jit.trace() to trace a network that takes 2 inputs

I am trying to use torch.jit.trace() to trace a network that takes 2 tensors (z and x) as input and produces one tensor (y) as output.

I can successfully trace models with 1 input and 1 output as follows:

  sample_input = torch.rand(1,3,256,256)
  traced_module_feature = torch.jit.trace(model, sample_input)

However, I have no idea how to trace models that take 2 inputs. I tried the following, which produces the error below:

z = torch.rand(1,3,127,127)
x = torch.rand(1,3,256,256)
traced_module_rpn = torch.jit.trace(model, { 'zf':z, 'xf':x } )

and the output is

Traceback (most recent call last):
  File "export_to_cpp.py", line 57, in <module>
  File "export_to_cpp.py", line 52, in main
    traced_module_rpn = torch.jit.trace(model.rpn_model, { 'zf':z, 'xf':x } )
  File "/home/tlm/anaconda3/envs/svt2_pyth1/lib/python3.7/site-packages/torch/jit/__init__.py", line 565, in trace
    module._create_method_from_trace('forward', func, example_inputs)
RuntimeError: Only tensors and (possibly nested) tuples of tensors are supported as inputs or outputs of traced functions (toIValue at /opt/conda/conda-bld/pytorch-nightly_1538562647654/work/torch/csrc/jit/pybind_utils.h:74)
frame #0: <unknown function> + 0x3fe53f (0x7f98c5fb353f in /home/tlm/anaconda3/envs/svt2_pyth1/lib/python3.7/site-packages/torch/_C.cpython-37m-x86_64-linux-gnu.so)
frame #1: <unknown function> + 0x463a2b (0x7f98c6018a2b in /home/tlm/anaconda3/envs/svt2_pyth1/lib/python3.7/site-packages/torch/_C.cpython-37m-x86_64-linux-gnu.so)
frame #2: <unknown function> + 0x1a665d (0x7f98c5d5b65d in /home/tlm/anaconda3/envs/svt2_pyth1/lib/python3.7/site-packages/torch/_C.cpython-37m-x86_64-linux-gnu.so)
<omitting python frames>
frame #19: __libc_start_main + 0xf0 (0x7f98da43b830 in /lib/x86_64-linux-gnu/libc.so.6)

I am using pytorch-1.0-rc1 with cuda-9.0 and python 3.7.

Thank you


And this is the network that I am trying to trace:

import torch
import torch.nn as nn
import torch.nn.functional as F

class RPN(nn.Module):
    def __init__(self):
        super(RPN, self).__init__()

    def forward(self, z_f, x_f):
        raise NotImplementedError

def conv2d_group(x, kernel):
    batch = kernel.size()[0]
    pk = kernel.view(-1, x.size()[1], kernel.size()[2], kernel.size()[3])
    px = x.view(1, -1, x.size()[2], x.size()[3])
    po = F.conv2d(px, pk, groups=batch)
    po = po.view(batch, -1, po.size()[2], po.size()[3])
    return po

class UPChannelRPN(RPN):
    def __init__(self, anchor_num=5, feature_in=256, feature_out=256):
        super(UPChannelRPN, self).__init__()

        self.anchor_num = anchor_num
        self.feature_in = feature_in
        self.feature_out = feature_out

        self.cls_output = 2 * self.anchor_num
        self.loc_output = 4 * self.anchor_num

        self.template_cls_conv = nn.Conv2d(self.feature_in, self.feature_out * self.cls_output, kernel_size=3)
        self.template_loc_conv = nn.Conv2d(self.feature_in, self.feature_out * self.loc_output, kernel_size=3)

        self.search_cls_conv = nn.Conv2d(self.feature_in, self.feature_out, kernel_size=3)
        self.search_loc_conv = nn.Conv2d(self.feature_in, self.feature_out, kernel_size=3)

        self.loc_adjust = nn.Conv2d(self.loc_output, self.loc_output, kernel_size=1)

    def forward(self, z_f, x_f):
        cls_kernel = self.template_cls_conv(z_f)
        loc_kernel = self.template_loc_conv(z_f)

        cls_feature = self.search_cls_conv(x_f)
        loc_feature = self.search_loc_conv(x_f)

        pred_cls = conv2d_group(cls_feature, cls_kernel)
        pred_loc = self.loc_adjust(conv2d_group(loc_feature, loc_kernel))
        return pred_cls, pred_loc

I am trying to trace UPChannelRPN(). To the best of my knowledge, this network does not include any conditionals, so I think it should be possible to trace it using torch.jit.trace.

The trace works when I use:

traced_module_rpn = torch.jit.trace(model.rpn_model, [ Variable(z), Variable(x) ] )

but with this warning:

rpn.py:18: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  po = F.conv2d(px, pk, groups=batch)

where line 18 of rpn.py is

    po = F.conv2d(px, pk, groups=batch)

Any help in understanding this warning message is highly appreciated.
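The warning comes from `batch = kernel.size()[0]`: `groups` must be a plain Python int, so the tracer has to convert the traced size into a constant, and the batch size of the example inputs gets frozen into the trace. One possible workaround, sketched here under the assumption that mixing scripting and tracing is acceptable for your export, is to compile `conv2d_group` with `@torch.jit.script` so the batch size is read at run time; a traced module that calls a scripted function keeps that call as script. `TinyHead` and the small channel sizes are hypothetical, chosen only to keep the example fast.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# Scripted version of conv2d_group from the question: under torch.jit.script,
# kernel.size(0) is evaluated at run time instead of being frozen by the tracer.
@torch.jit.script
def conv2d_group(x: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    batch = kernel.size(0)
    pk = kernel.view(-1, x.size(1), kernel.size(2), kernel.size(3))
    px = x.view(1, -1, x.size(2), x.size(3))
    po = F.conv2d(px, pk, groups=batch)
    return po.view(batch, -1, po.size(2), po.size(3))


class TinyHead(nn.Module):
    """Hypothetical, scaled-down stand-in for one branch of UPChannelRPN."""

    def __init__(self, feature=4, out_per_sample=2):
        super().__init__()
        self.template_conv = nn.Conv2d(feature, feature * out_per_sample, kernel_size=3)
        self.search_conv = nn.Conv2d(feature, feature, kernel_size=3)

    def forward(self, z_f, x_f):
        kernel = self.template_conv(z_f)  # (B, 8, 3, 3) for 5x5 input
        search = self.search_conv(x_f)    # (B, 4, 8, 8) for 10x10 input
        return conv2d_group(search, kernel)


head = TinyHead().eval()
# Trace with batch size 1, passing the two inputs as a tuple.
traced = torch.jit.trace(head, (torch.rand(1, 4, 5, 5), torch.rand(1, 4, 10, 10)))

# Run with batch size 2: the scripted helper picks up the new batch size.
z2, x2 = torch.rand(2, 4, 5, 5), torch.rand(2, 4, 10, 10)
print(traced(z2, x2).shape)  # torch.Size([2, 2, 6, 6])
```

With a purely traced `conv2d_group`, `groups` would stay at 1 and the batch-2 call would not match the eager result; scripting only the helper keeps the rest of the model traced.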

I think you can try passing the inputs as a tuple.
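A minimal sketch of the tuple approach, using a hypothetical `TwoInputNet` in place of the RPN head: `torch.jit.trace` unpacks a tuple of example inputs into positional arguments, so `(z, x)` is passed as `forward(z, x)`. A dict such as `{'zf': z, 'xf': x}` is what the error message above rejects, since it only accepts tensors and (possibly nested) tuples of tensors.

```python
import torch
import torch.nn as nn


class TwoInputNet(nn.Module):
    """Hypothetical stand-in for a network that maps (z, x) to a single output y."""

    def __init__(self):
        super().__init__()
        self.conv_z = nn.Conv2d(3, 8, kernel_size=3)
        self.conv_x = nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, z, x):
        # Pool each branch to (N, 8) so inputs of different sizes can be combined.
        y = self.conv_z(z).mean(dim=(2, 3)) + self.conv_x(x).mean(dim=(2, 3))
        return y


model = TwoInputNet().eval()
z = torch.rand(1, 3, 127, 127)
x = torch.rand(1, 3, 256, 256)

# A tuple of example inputs is unpacked into positional arguments: forward(z, x).
traced = torch.jit.trace(model, (z, x))
y = traced(z, x)
print(y.shape)  # torch.Size([1, 8])
```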

Hello, I would like to see the trace graphically through:
traced_net = torch.jit.trace(net, inputs, optimize=True, check_trace=True, check_inputs=None, check_tolerance=1e-05)
make_dot_from_trace(traced_net)

but the make_dot_from_trace call fails with:
AttributeError: 'TopLevelTracedModule' object has no attribute 'set_graph'
If there is no graph, then what?
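`make_dot_from_trace` is not part of PyTorch itself (it appears to come from a third-party visualization package), and the AttributeError suggests it expects an attribute (`set_graph`) that this traced module does not have. If the goal is only to inspect what was traced, the module returned by `torch.jit.trace` exposes its TorchScript IR through `.graph` and a Python-like rendering through `.code`. A minimal sketch with a hypothetical two-input `SmallNet`:

```python
import torch
import torch.nn as nn


class SmallNet(nn.Module):
    """Hypothetical two-input module used only to produce a small graph."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 2)

    def forward(self, a, b):
        return torch.relu(self.lin(a) + b)


traced = torch.jit.trace(SmallNet(), (torch.rand(1, 4), torch.rand(1, 2)))

print(traced.graph)  # TorchScript IR: one node per recorded op (e.g. aten::relu)
print(traced.code)   # the same graph rendered as Python-like source
```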

Any insights on how to solve this? I have a similar problem: I am trying to trace a model with an RPN head that takes 2 inputs.


This works

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x, y):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight

n = Net()

example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# targets
H, W = 1000, 700
num_classes = 10
image_id = torch.as_tensor(1693029438, dtype=torch.int32)
num_objects = 1200
image = torch.randn((H, W))
masks = torch.randint(0, 1, (num_objects, H, W), dtype=torch.uint8) + 255
areas = torch.randint(100, 20000, (1, num_objects), dtype=torch.int64)
boxes = torch.randint(100, H * W, (num_objects, 4), dtype=torch.int64)
labels = torch.randint(1, num_classes, (1, num_objects), dtype=torch.int64)
iscrowd = torch.zeros(num_objects, dtype=torch.int64)

target = {}
target["image_id"] = image_id
target["boxes"] = boxes
target["labels"] = labels
target["area"] = areas
target["iscrowd"] = iscrowd
target["masks"] = masks

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, (example_forward_input, target))

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, (example_forward_input, target))

# Trace specific methods on a module (specified in `inputs`), constructs
# a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {'forward' : (example_forward_input, target), 'weighted_kernel_sum' : example_weight}
module = torch.jit.trace_module(n, inputs)

This was taken from here