Jit trace tuple construct and unpack; can't keep track of inputs

Hi everyone,

I have the following Jupyter notebook, in which I access the intermediate layers of the backbone using a custom class. The problem is that, as shown in the screenshot, the traced graph contains a tuple construct and unpack operation that obscures the fact that the multiplication by 2 is applied to the skip connection and not to the output of the backbone. This is just a toy reproduction of an issue I have with a bigger network. Is there a way to avoid the tuple construct and unpack operations, so that it is clear where the arrows come from?

first cell:

from collections import OrderedDict
from tqdm import tqdm
import network
import utils
import os
import random
import argparse
import numpy as np

from torch.utils import data
from datasets import VOCSegmentation, Cityscapes, sTilesSegmentation
from utils import ext_transforms as et
from metrics import StreamSegMetrics

import torch
import torch.nn as nn
from utils.visualizer import Visualizer

from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
import time

class simple_backbone(nn.Module):
    """Three stacked 3x3 convolutions, each mapping 3 channels to 3 channels.

    The convolutions are registered as direct children named ``conv1``,
    ``conv2`` and ``conv3`` (in that order), so a wrapper that walks
    ``named_children()`` can tap the output of any of them.
    """

    def __init__(self):
        super().__init__()
        # With a 3x3 kernel and the default stride of 1, padding=1 keeps
        # the spatial size unchanged.
        self.conv1 = nn.Conv2d(3, 3, 3, padding=1)
        self.conv2 = nn.Conv2d(3, 3, 3, padding=1)
        self.conv3 = nn.Conv2d(3, 3, 3, padding=1)

    def forward(self, x):
        """Apply ``conv1``, ``conv2`` and ``conv3`` in sequence."""
        for stage in (self.conv1, self.conv2, self.conv3):
            x = stage(x)
        return x
class simple_classifier(nn.Module):
    """Fuse the ``'out'`` and ``'middle'`` feature maps with one 3x3 conv.

    ``features['middle']`` is doubled and concatenated after
    ``features['out']`` along dim 1. The conv expects that concatenation
    to have 6 channels (e.g. 3 + 3) and produces 3 output channels.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(6, 3, 3, padding=1)

    def forward(self, features):
        """Return ``conv(cat([features['out'], 2 * features['middle']], dim=1))``.

        Args:
            features: mapping containing at least the keys ``'out'`` and
                ``'middle'``.
        """
        deep = features['out']
        skip = features['middle'] * 2
        return self.conv(torch.cat([deep, skip], dim=1))

class simple_compound_network(nn.Module):
    """Chain a feature-extracting backbone with a classifier head.

    Args:
        backbone: module whose output is passed unchanged to ``classifier``
            (in this notebook, a dict of named feature maps).
        classifier: module that consumes the backbone output.
    """

    def __init__(self, backbone, classifier):
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier

    def forward(self, x):
        """Return ``classifier(backbone(x))``."""
        return self.classifier(self.backbone(x))
class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model.

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    Children registered after the last requested layer are dropped,
    because their outputs are never needed.

    Arguments:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
            The mapping is copied, so mutating the caller's dict after
            construction does not affect this module.
        hrnet_flag (bool): stored as ``self.hrnet_flag``; it is not read
            by ``forward`` in this class.

    Raises:
        ValueError: if a key of ``return_layers`` is not the name of a
            direct child of ``model``.

    Examples::

        >>> m = torchvision.models.resnet18()
        >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
        >>> new_m = IntermediateLayerGetter(m,
        >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = new_m(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        >>>     [('feat1', torch.Size([1, 64, 56, 56])),
        >>>      ('feat2', torch.Size([1, 256, 14, 14]))]
    """
    def __init__(self, model, return_layers, hrnet_flag=False):
        child_names = [name for name, _ in model.named_children()]
        if not set(return_layers).issubset(child_names):
            raise ValueError("return_layers are not present in model")

        # Keep children up to and including the last requested one; later
        # children could never contribute to the returned activations.
        pending = set(return_layers)
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            pending.discard(name)
            if not pending:
                break

        # Register the submodules first, then plain attributes, following the
        # nn.Module convention of running the base __init__ before assigning
        # anything on self.
        super(IntermediateLayerGetter, self).__init__(layers)
        self.hrnet_flag = hrnet_flag
        # Defensive copy: the caller's dict must not be able to change which
        # activations forward() returns after construction.
        self.return_layers = dict(return_layers)
        # new_name -> position of its entry in return_layers (insertion order).
        self.return_mapping = {
            new_name: i for i, new_name in enumerate(self.return_layers.values())
        }

    def forward(self, x):
        """Run the retained children in order and collect requested outputs.

        Args:
            x: input to the first retained child; each child's output is fed
                to the next one.

        Returns:
            OrderedDict mapping each requested ``new_name`` to the output of
            its child, in the order the children are executed.
        """
        out = OrderedDict()
        for name, module in self.named_children():
            x = module(x)
            if name in self.return_layers:
                out[self.return_layers[name]] = x
        return out

second cell:

# Second cell: build the toy model, run it once eagerly, then trace and save it.
backbone = simple_backbone()

classifier = simple_classifier()

# Expose conv1 as 'middle' (the skip connection) and conv3 as 'out'.
backbone = IntermediateLayerGetter(backbone, {'conv3': 'out', 'conv1': 'middle'})

model = simple_compound_network(backbone, classifier)

# Trace in inference mode (no effect on plain convs, but matters once
# dropout / batch-norm layers are present).
model.eval()

# Create the input directly as float32. Tensors carry autograd state since
# PyTorch 0.4, so the former `.type(torch.FloatTensor)` conversion and the
# deprecated `torch.autograd.Variable` wrapper are unnecessary.
pt_inputs = torch.randn(1, 3, 576, 520, dtype=torch.float32)

# Eager reference run.
pt_pred = model(pt_inputs)

traced_model = torch.jit.trace(model, pt_inputs)

# NOTE(review): the dict returned by the backbone is presumably what shows up
# as tuple construct/unpack nodes between the backbone and classifier calls.
# `traced_model.inlined_graph` prints the graph with submodule calls inlined,
# which may make the skip-connection data flow easier to follow — confirm
# against your PyTorch version.

traced_model.save("model_test.pt")