Add auxiliary branch to pretrained ResNet


Inspired by the Inception architecture (which has auxiliary branches attached partway through the network), I would like to try adding auxiliary branches to each block of a ResNet, namely model.layer1 through model.layer4 in model = resnet18(pretrained=True).

While I can easily replace model.fc with the layer defined below to get multi-task learning, there seems to be no easy way to do the same for model.layer1 through model.layer4 without changing the original model definition.

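```python
import torch.nn as nn


class MultiTaskBranch(nn.Module):
    def __init__(self, input_size, num_classes):
        super().__init__()
        # ReLU between the two Linear layers; without it each head collapses to one linear map
        self.multi_task_classifier = nn.Sequential(nn.Linear(input_size, 100), nn.ReLU(),
                                                   nn.Linear(100, num_classes))
        self.classifier = nn.Sequential(nn.Linear(input_size, 100), nn.ReLU(),
                                        nn.Linear(100, num_classes))

    def forward(self, x):
        # Main prediction and auxiliary (multi-task) prediction from the same features
        return self.classifier(x), self.multi_task_classifier(x)
```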

So my questions are:

  • Is it possible to achieve this without changing the original architecture (maybe with something like a hook)?
  • If not, and I decide to change the original architecture, can I still use the pretrained weights?

Could someone help me? Thank you in advance!

Hello, have you figured this out?
I am dealing with the exact same problem where I want to modify hidden layers of a pretrained network s.t. they output auxiliary loss values.
Any help will be very much appreciated :slight_smile:

What do you mean by:

I want to modify hidden layers of a pretrained network s.t. they output auxiliary loss values.

Do you want to change the input/output dimensions of the network? Or do you want to add layers to the model?

In the first case, your network may become incompatible with the pretrained weights because of size mismatches between the weight tensors.

In the second case, I would recommend writing the network as a class that you can modify as you like (if you are using models from the model zoo, you can simply copy their code) and then loading only the compatible weights, like so:

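```python
import torch

# `model` is an instance of your modified network class
pretrained_dict = torch.load('path_to_weights.pth')
model_dict = model.state_dict()

# 1. filter out unnecessary keys (and any whose shapes no longer match)
pretrained_dict = {k: v for k, v in pretrained_dict.items()
                   if k in model_dict and v.shape == model_dict[k].shape}

# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)

# 3. load the new state dict
model.load_state_dict(model_dict)
```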


First of all, I appreciate your trying to help me out :slight_smile:

But I think you may have missed the point of the OP’s question, which was about “adding auxiliary branches to a pretrained network”.

The problem with this is that pretrained models already have:

  1. a fixed architecture, including fixed weight sizes, as you’ve pointed out
  2. fixed input arguments for forward()

Now, what I mean by adding auxiliary branches (and I’m quite confident OP also meant the same thing) is something like this:

A. Original module: out = module(input)
B. Modified module: out, aux_out = module(input)

I already know how to modify the intermediate module like B.

However, the problem is point 2, the fixed input arguments for forward():

Even if I change the modified module to output a tuple (out, aux_out), the following module will not be able to accept that as an input.

A workaround may be to use forward hooks to access the intermediate outputs of hidden layers directly, but I’m not sure how the gradients will be affected in this case.
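On the gradient question: a forward hook receives the very tensor that the next layer consumes, so as long as it is stored without .detach() (and not under torch.no_grad()), it stays in the autograd graph. A small self-contained check on a toy network (the layer sizes and aux_head are made up for illustration) shows that an auxiliary loss computed from a hooked activation produces gradients for the hooked layer but not for the layers after it:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
aux_head = nn.Linear(8, 2)  # hypothetical auxiliary head
captured = {}

# Store the first layer's output; the hook returns None, so the output is unchanged
net[0].register_forward_hook(lambda m, inp, out: captured.__setitem__("hidden", out))

x = torch.randn(5, 4)
main_out = net(x)
aux_loss = aux_head(captured["hidden"]).pow(2).mean()
aux_loss.backward()

print(net[0].weight.grad is not None)  # True: the aux loss reaches the hooked layer
print(net[2].weight.grad)              # None: later layers are not part of the aux loss
```

In training, the usual pattern is to sum the losses, e.g. (main_loss + 0.3 * aux_loss).backward(), so the hooked layers receive the sum of both gradients (the 0.3 weight is an arbitrary example).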

This might be too late to help the OP, but for anyone who meets a similar problem:

“Taking an advantage of forward hook paradigm in PyTorch [30], torchdistill supports introducing such auxiliary modules without altering the original implementations of the models.”
torchdistill: A Modular, Configuration-Driven Framework for Knowledge Distillation, Yoshitomo (2020 RRPR)

So it seems that torchdistill already has this functionality.

Even for those who are not working on knowledge distillation (like the OP and me), it’s good to know that there is a solid implementation of this. It would be a nice starting point :slight_smile: