Hi everyone,
I have the following Jupyter notebook where I access the intermediate layers of the backbone using a custom class. The problem is that, as shown in the screenshot, there is a tuple construct and unpack operation that obscures the fact that the multiplication by 2 is applied to the skip connection and not to the output of the backbone. This is just a toy reproduction of an issue I have with a bigger network structure. Is there a way to avoid the tuple construct and unpack operations so that it is clear where the arrows come from?
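In case the screenshot does not come through, the relevant part of the traced graph looks roughly like this (a hand-written sketch of the TorchScript IR, not the literal output):

# sketch: the backbone's named outputs get packed into an anonymous tuple...
%features : (Tensor, Tensor) = prim::TupleConstruct(%out, %middle)
# ...and unpacked again in front of the classifier, so the names are lost:
%0, %1 = prim::TupleUnpack(%features)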
first cell:
# only the imports needed for this toy reproduction
from collections import OrderedDict

import torch
import torch.nn as nn
class simple_backbone(nn.Module):
    # toy backbone: three 3x3 convs; conv1's output will serve as the skip connection
    def __init__(self):
        super(simple_backbone, self).__init__()
        self.conv1 = nn.Conv2d(3, 3, 3, padding=1)
        self.conv2 = nn.Conv2d(3, 3, 3, padding=1)
        self.conv3 = nn.Conv2d(3, 3, 3, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return x
class simple_classifier(nn.Module):
    def __init__(self):
        super(simple_classifier, self).__init__()
        self.conv = nn.Conv2d(6, 3, 3, padding=1)

    def forward(self, features):
        # the multiplication by 2 is applied to the skip connection ('middle'),
        # not to the backbone output ('out')
        x = self.conv(torch.cat([features['out'], features['middle'] * 2], dim=1))
        return x
class simple_compound_network(nn.Module):
    def __init__(self, backbone, classifier):
        super(simple_compound_network, self).__init__()
        self.backbone = backbone
        self.classifier = classifier

    def forward(self, x):
        features = self.backbone(x)  # OrderedDict of named feature maps
        x = self.classifier(features)
        return x
class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model.

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    Arguments:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).

    Examples::

        >>> m = torchvision.models.resnet18(pretrained=True)
        >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
        >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
        >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = new_m(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        >>> [('feat1', torch.Size([1, 64, 56, 56])),
        >>>  ('feat2', torch.Size([1, 256, 14, 14]))]
    """
    def __init__(self, model, return_layers, hrnet_flag=False):
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")
        self.hrnet_flag = hrnet_flag
        orig_return_layers = return_layers
        return_layers = {k: v for k, v in return_layers.items()}  # copy that gets consumed below
        self.return_mapping = {k_v_pair[1]: i for i, k_v_pair in enumerate(return_layers.items())}  # name -> order (not used below)
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break
        super(IntermediateLayerGetter, self).__init__(layers)
        self.return_layers = orig_return_layers
    def forward(self, x):
        out = OrderedDict()
        for name, module in self.named_children():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        # this OrderedDict is what ends up as the anonymous tuple in the traced graph
        return out
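For context, this class is adapted from torchvision.models._utils.IntermediateLayerGetter, and as far as I can tell the OrderedDict it returns is what the tracer flattens into the anonymous tuple. Would returning a typed NamedTuple instead keep the names in the graph? Something along these lines (just a sketch; I don't know whether the tracer actually preserves the field names):

from typing import NamedTuple

class BackboneFeatures(NamedTuple):
    # hypothetical return type; the field names match the keys used above
    out: torch.Tensor
    middle: torch.Tensor

# in IntermediateLayerGetter.forward, instead of returning the OrderedDict:
#     return BackboneFeatures(out=out['out'], middle=out['middle'])
# and in simple_classifier.forward, access features.out / features.middle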
second cell:
backbone = simple_backbone()
classifier = simple_classifier()
# wrap the backbone so that conv3's output is returned as 'out' and conv1's as 'middle'
backbone = IntermediateLayerGetter(backbone, {'conv3': 'out', 'conv1': 'middle'})
model = simple_compound_network(backbone, classifier)
model.eval()

pt_inputs = torch.randn(1, 3, 576, 520)  # already float32; the deprecated Variable wrapper is not needed
pt_pred = model(pt_inputs)

traced_model = torch.jit.trace(model, pt_inputs)
traced_model.save("model_test.pt")
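For completeness, the same thing is visible in text form without the screenshot; if I understand the TorchScript serialization correctly, printing the generated code shows the dict being packed into, and unpacked from, a plain tuple:

# inspect the traced model as code instead of a graph
print(traced_model.backbone.code)  # the backbone's OrderedDict comes back as a tuple
print(traced_model.code)           # the classifier's inputs come from a tuple unpack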