Add auxiliary branch to pretrained ResNet

Problem

Inspired by the architecture of Inception (which has auxiliary branches in the middle of the network), I would like to try adding auxiliary branches to each block in ResNet, namely model.layer1 through model.layer4 in model = resnet18(pretrained=True).

Although I can easily replace model.fc with the layer defined below to achieve multi-task learning, there seems to be no easy way to do the same for model.layer1 through model.layer4 without changing the original model definition.

import torch.nn as nn


class MultiTaskBranch(nn.Module):
    def __init__(self, input_size, num_classes):
        super(MultiTaskBranch, self).__init__()

        self.multi_task_classifier = nn.Sequential(nn.Linear(input_size, 100),
                                                   nn.BatchNorm1d(num_features=100),
                                                   nn.LeakyReLU(inplace=True),
                                                   nn.Linear(100, num_classes))
        self.classifier = nn.Sequential(nn.Linear(input_size, 100),
                                        nn.BatchNorm1d(num_features=100),
                                        nn.LeakyReLU(inplace=True),
                                        nn.Linear(100, num_classes))

    def forward(self, x):
        # returns (main-task logits, auxiliary-task logits)
        return self.classifier(x), self.multi_task_classifier(x)

So my questions are:

  • Is it possible to achieve this without changing the original architecture (perhaps with something like hooks)?
  • If it is not possible and I decide to change the original architecture, can I still use the pretrained weights?

Could someone help me? Thank you in advance!

Hello, have you figured this out?
I am dealing with the exact same problem: I want to modify hidden layers of a pretrained network so that they output auxiliary loss values.
Any help will be very much appreciated :slight_smile:

What do you mean by:

I want to modify hidden layers of a pretrained network s.t. they output auxiliary loss values.

Do you want to change the input/output dimensions of the network? Or do you want to add layers to the model?

In the first case, the change might make your network incompatible with your pretrained weights because of size mismatches between the weight tensors.

In the second case, I would recommend writing the network as your own class that you can modify as you want (if you are using the models from the model zoo, you can simply copy their code) and then loading only the compatible weights, like so:

import torch

# model: your modified network, defined elsewhere
pretrained_dict = torch.load('path_to_weights.pth')
model_dict = model.state_dict()

# 1. filter out unnecessary keys, and tensors whose shapes no longer match
pretrained_dict = {k: v for k, v in pretrained_dict.items()
                   if k in model_dict and v.shape == model_dict[k].shape}

# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)

# 3. load the new state dict
model.load_state_dict(model_dict)

Hello!

First of all, I appreciate your trying to help me out :slight_smile:

But I think you haven’t read the OP’s question carefully; it was about “adding auxiliary branches to a pretrained network”.

The problem with this is that pretrained models already have:

  1. a fixed architecture, which includes fixed weight sizes, as you’ve pointed out
  2. fixed input arguments for forward()

Now, what I mean by adding auxiliary branches (and I’m quite confident the OP meant the same thing) is something like this:

A. Original module: out = module(input)
B. Modified module: out, aux_out = module(input)

I already know how to modify an intermediate module as in B.

However, the problem lies with point 2, the fixed input arguments of forward():

If I change the intermediate module to output a tuple (out, aux_out), the next module cannot accept that tuple as its input.

A possible workaround is to use forward hooks to access the intermediate outputs of hidden layers directly, but I’m not sure how gradients are affected in this case.

This might be too late to help the OP, but for anyone who runs into a similar problem:

“Taking an advantage of forward hook paradigm in PyTorch [30], torchdistill supports introducing such auxiliary modules without altering the original implementations of the models.”
torchdistill: A Modular, Configuration-Driven Framework for Knowledge Distillation, Yoshitomo Matsubara (RRPR 2020)

So it seems that torchdistill already has this functionality.

Even for those who are not directly dealing with knowledge distillation (like the OP and me), it’s good to know that there is a solid implementation. It would be a nice starting point :slight_smile:


@111482 Could you please elaborate on how you implemented the intermediate module as in B?

Hi,

You can use PyTorch’s forward hooks to fetch intermediate outputs from specific layers. torchdistill provides an easy-to-use wrapper for forward/backward hooks, so you may use that. Unfortunately, I could not simply install it with pip install torchdistill because of its requirements, so I copied the relevant code from their repository. You can also refer to the Jupyter notebook provided in their repo.

Please try the code below:

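import logging
from collections import OrderedDict, abc

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
from torch.nn import ModuleList, Sequential
from torch.nn.parallel import DataParallel, DistributedDataParallel, gather
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import models
from torchvision import transforms

logger = logging.getLogger(__name__)
# torch._six.string_classes no longer exists in recent PyTorch; this is its Python 3 value
string_classes = (str, bytes)


# code adapted from https://github.com/yoshitomo-matsubara/torchdistill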
def get_module(root_module, module_path):
    module_names = module_path.split('.')
    module = root_module
    for module_name in module_names:
        if not hasattr(module, module_name):
            if isinstance(module, (DataParallel, DistributedDataParallel)):
                module = module.module
                if not hasattr(module, module_name):
                    if isinstance(module, Sequential) and module_name.lstrip('-').isnumeric():
                        module = module[int(module_name)]
                    else:
                        logger.info('`{}` of `{}` could not be reached in `{}`'.format(module_name, module_path,
                                                                                       type(root_module).__name__))
                else:
                    module = getattr(module, module_name)
            elif isinstance(module, (Sequential, ModuleList)) and module_name.lstrip('-').isnumeric():
                module = module[int(module_name)]
            else:
                logger.info('`{}` of `{}` could not be reached in `{}`'.format(module_name, module_path,
                                                                               type(root_module).__name__))
                return None
        else:
            module = getattr(module, module_name)
    return module


def get_hierarchized_dict(module_paths):
    children_dict = OrderedDict()
    for module_path in module_paths:
        elements = module_path.split('.')
        if elements[0] not in children_dict and len(elements) == 1:
            children_dict[elements[0]] = module_path
            continue
        elif elements[0] not in children_dict:
            children_dict[elements[0]] = list()
        children_dict[elements[0]].append('.'.join(elements[1:]))

    for key in children_dict.keys():
        value = children_dict[key]
        if isinstance(value, list) and len(value) > 1:
            children_dict[key] = get_hierarchized_dict(value)
    return children_dict
        
def get_device_index(data):
    if isinstance(data, torch.Tensor):
        device = data.device
        return 'cpu' if device.type == 'cpu' else device.index
    elif isinstance(data, abc.Mapping):
        for key, data in data.items():
            result = get_device_index(data)
            if result is not None:
                return result
    elif isinstance(data, tuple):
        for d in data:
            result = get_device_index(d)
            if result is not None:
                return result
    elif isinstance(data, abc.Sequence) and not isinstance(data, string_classes):
        for d in data:
            result = get_device_index(d)
            if result is not None:
                return result
    return None


def register_forward_hook_with_dict(module, module_path, requires_input, requires_output, io_dict):
    io_dict[module_path] = dict()

    def forward_hook4input(self, func_input, func_output):
        if isinstance(func_input, tuple) and len(func_input) == 1:
            func_input = func_input[0]

        device_index = get_device_index(func_output)
        sub_io_dict = io_dict[module_path]
        if 'input' not in sub_io_dict:
            sub_io_dict['input'] = dict()
        sub_io_dict['input'][device_index] = func_input

    def forward_hook4output(self, func_input, func_output):
        if isinstance(func_output, tuple) and len(func_output) == 1:
            func_output = func_output[0]

        device_index = get_device_index(func_output)
        sub_io_dict = io_dict[module_path]
        if 'output' not in sub_io_dict:
            sub_io_dict['output'] = dict()
        sub_io_dict['output'][device_index] = func_output

    def forward_hook4io(self, func_input, func_output):
        if isinstance(func_input, tuple) and len(func_input) == 1:
            func_input = func_input[0]
        if isinstance(func_output, tuple) and len(func_output) == 1:
            func_output = func_output[0]

        device_index = get_device_index(func_output)
        sub_io_dict = io_dict[module_path]
        if 'input' not in sub_io_dict:
            sub_io_dict['input'] = dict()

        if 'output' not in sub_io_dict:
            sub_io_dict['output'] = dict()

        sub_io_dict['input'][device_index] = func_input
        sub_io_dict['output'][device_index] = func_output

    if requires_input and not requires_output:
        return module.register_forward_hook(forward_hook4input)
    elif not requires_input and requires_output:
        return module.register_forward_hook(forward_hook4output)
    elif requires_input and requires_output:
        return module.register_forward_hook(forward_hook4io)
    raise ValueError('Either requires_input or requires_output should be True')

    

class ForwardHookManager(object):
    def __init__(self, target_device):
        self.target_device = torch.device(target_device) if isinstance(target_device, str) else target_device
        self.uses_cuda = self.target_device.type == 'cuda'
        self.io_dict = dict()
        self.hook_list = list()

    def add_hook(self, module, module_path, requires_input=True, requires_output=True):
        sub_module = get_module(module, module_path)
        handle = \
            register_forward_hook_with_dict(sub_module, module_path, requires_input, requires_output, self.io_dict)
        self.hook_list.append((module_path, handle))

    def pop_io_dict(self):
        gathered_io_dict = dict()
        for module_path, module_io_dict in self.io_dict.items():
            gathered_io_dict[module_path] = dict()
            for io_type in list(module_io_dict.keys()):
                sub_dict = module_io_dict.pop(io_type)
                values = [sub_dict[key] for key in sorted(sub_dict.keys())]
                gathered_obj = gather(values, self.target_device) if self.uses_cuda and len(values) > 1 else values[-1]
                gathered_io_dict[module_path][io_type] = gathered_obj
        return gathered_io_dict

    def pop_io_dict_from_device(self, device):
        device_io_dict = dict()
        device_key = device.index if device.type == 'cuda' else device.type
        for module_path, module_io_dict in self.io_dict.items():
            device_io_dict[module_path] = dict()
            for io_type in list(module_io_dict.keys()):
                sub_dict = module_io_dict[io_type]
                device_io_dict[module_path][io_type] = sub_dict.pop(device_key)
        return device_io_dict

    def change_target_device(self, target_device):
        if self.target_device.type != target_device.type:
            for sub_dict in self.io_dict.values():
                sub_dict.clear()
        self.target_device = target_device

    def clear(self):
        self.io_dict.clear()
        for _, handle in self.hook_list:
            handle.remove()
        self.hook_list.clear()


# prepare dataset: We'll use CIFAR-10 as an example
trf = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
])
dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=trf)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
loader = iter(dataloader)


# device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# fetch model: let's use ResNet18
model = models.resnet18(pretrained=False).to(device)
model.eval()


# attach forward hooks
forward_hook_manager = ForwardHookManager(device)
forward_hook_manager.add_hook(model, 'layer1.0.conv1', requires_input=True, requires_output=False)
forward_hook_manager.add_hook(model, 'layer2.0.conv1', requires_input=True, requires_output=False)
forward_hook_manager.add_hook(model, 'layer3.0.conv1', requires_input=True, requires_output=False)
forward_hook_manager.add_hook(model, 'layer4.0.conv1', requires_input=True, requires_output=False)


# forward pass through ResNet18
x, l = next(loader)
with torch.no_grad():
    y = model(x.to(device))

    
# fetch auxiliary feature maps
io_dict = forward_hook_manager.pop_io_dict()
print(io_dict.keys(), '\n')


# visualize
N = len(io_dict)
fig,ax = plt.subplots(1, N+1, figsize=(5*(N+1), 5))
ax[0].imshow(x.cpu()[0].permute(1,2,0))
ax[0].set_title("Original Image", fontsize=16)
for i, k in enumerate(io_dict.keys()):
    feature_maps = io_dict[k]['input']
    
    # average feature maps channel-wise for visualization: (1, C, H, W) -> (H, W)
    feature_maps_avg = feature_maps[0].mean(dim=0).cpu()
    ax[i+1].imshow(feature_maps_avg)
    ax[i+1].set_title(k, fontsize=16)

You should get a figure showing the original image followed by the channel-averaged input feature maps of each hooked layer.

Hope this helps!