Composition of Multiple Networks


I would like to build an Inception-style architecture on top of ResNet, with auxiliary branches attached to intermediate layers (the branch circled by the blue rectangle in the figure).

However, I would like to find the optimal position for my auxiliary branch. Since ResNet has 4 blocks, I see two options:

  • Build four largely identical networks that differ only in where the auxiliary branch is attached. The downside is that this requires a lot of boilerplate code.
  • Build only one network and compose ResNet and the auxiliary branch in some smarter way. A promising approach is to use register_forward_hook and register_backward_hook.
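A minimal sketch of a third route, which avoids both the boilerplate and the hooks: a single module whose auxiliary head position is a constructor argument. The class name `BranchedNetwork` and the `aux_after` parameter are hypothetical; the channel widths follow `layer1`..`layer4` from the code below, the ReLU after each conv is an assumption, and both heads are simplified to a single `Linear` layer.

```python
import torch
from torch import nn


class BranchedNetwork(nn.Module):
    """Hypothetical sketch: four conv blocks plus one auxiliary head at a configurable position."""

    def __init__(self, aux_after, num_classes=10):
        super().__init__()
        channels = [1, 32, 64, 128, 256]  # same widths as layer1..layer4 in the question
        # ReLU after each conv is an assumption (the original blocks were truncated).
        self.blocks = nn.ModuleList([
            nn.Sequential(nn.Conv2d(channels[i], channels[i + 1], kernel_size=3, padding=1), nn.ReLU())
            for i in range(4)
        ])
        self.aux_after = aux_after  # 0..3: index of the block that feeds the auxiliary head
        self.aux_head = nn.Sequential(
            nn.AdaptiveAvgPool2d((7, 7)),
            nn.Flatten(),
            nn.Linear(channels[aux_after + 1] * 7 * 7, num_classes),
        )
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((7, 7)),
            nn.Flatten(),
            nn.Linear(channels[-1] * 7 * 7, num_classes),
        )

    def forward(self, x):
        aux_out = None
        for i, block in enumerate(self.blocks):
            x = block(x)
            if i == self.aux_after:
                # The auxiliary head reads the live activation, so its loss
                # backpropagates into blocks 0..aux_after through normal autograd.
                aux_out = self.aux_head(x)
        return self.classifier(x), aux_out


# Try every position with the same class instead of four hand-written networks.
for pos in range(4):
    model = BranchedNetwork(aux_after=pos)
    main_out, aux_out = model(torch.randn(2, 1, 28, 28))
    print(pos, tuple(main_out.shape), tuple(aux_out.shape))
```

In training, both outputs would be combined into one loss, e.g. `criterion(main_out, y) + w * criterion(aux_out, y)` with some weight `w` you choose, followed by a single `backward()`.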

However, the second option raises two difficulties:

  • How do I correctly preserve the gradient from the auxiliary branch? The code snippet below fails to do this; all I get is None.
  • If that problem is solved, would the code below work correctly?
import torch

from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader

from torchvision import transforms as T
from torchvision.datasets import MNIST

# module-level state shared with the hooks (`global` only has an effect inside functions)
y = None
grad = None

class MainNetwork(nn.Module):
    def __init__(self):
        super(MainNetwork, self).__init__()
        # each block was cut off in the original post; closing it with a ReLU is an assumption
        self.layer1 = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1), nn.ReLU())
        self.layer2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), nn.ReLU())
        self.layer3 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), nn.ReLU())
        self.layer4 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), nn.ReLU())
        self.classifier = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=(7, 7)),
                                        nn.Flatten(start_dim=1, end_dim=-1),
                                        nn.Linear(256*7*7, 1024),
                                        nn.Linear(1024, 10))
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        output = self.classifier(x)
        return output

class AuxiliaryNetwork(nn.Module):
    def __init__(self, inchannel):
        super(AuxiliaryNetwork, self).__init__()
        self.classifier = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=(7, 7)),
                                        nn.Flatten(start_dim=1, end_dim=-1),
                                        nn.Linear(inchannel*7*7, 1024),
                                        nn.Linear(1024, 10))
    def forward(self, x):
        output = self.classifier(x)
        return output

def forward_hook(self, input, output):
    global y
    global grad

    print("\tactivating forward hook")
    output.retain_grad()  # retain_grad is a method; .grad is only filled in after backward() runs

    score = model_aux(output)
    loss_aux = criterion(score, y)

    grad = output.grad

def backward_hook(self, grad_input, grad_output):
    global grad
    print("\tactivating backward hook")
    return (grad_input[0] + grad)

model = MainNetwork()
model_aux = AuxiliaryNetwork(inchannel=32)
opt = optim.Adam(model.parameters(), lr=1e-4)
opt_aux = optim.Adam(model_aux.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(reduction="mean")


dataset = MNIST(root=".", train=True, transform=T.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=32)

for X_train, y_train in dataloader:
    X_train = X_train.type(torch.float32)
    y_train = y_train.type(torch.int64)
    y = y_train


    score = model(X_train)
    loss = criterion(input=score, target=y_train)
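A possible answer to the gradient question, offered as a sketch rather than the definitive fix: `output.grad` is None inside the forward hook because `.grad` is only populated once `backward()` has run, and the forward hook fires before that. Instead of copying gradients by hand, the forward hook can compute and store the auxiliary logits; adding the auxiliary loss to the main loss and calling `backward()` once lets autograd route its gradient into `layer1`, so no backward hook is needed. The two-block `MainNetwork`, the `aux_scores` dict, `AUX_WEIGHT = 0.3`, and the random stand-in data are assumptions for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)


# Simplified two-block stand-in for the question's MainNetwork (ReLU assumed).
class MainNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1), nn.ReLU())
        self.layer2 = nn.Sequential(nn.Conv2d(32, 64, 3, padding=1), nn.ReLU())
        self.classifier = nn.Sequential(nn.AdaptiveAvgPool2d((7, 7)), nn.Flatten(),
                                        nn.Linear(64 * 7 * 7, 10))

    def forward(self, x):
        return self.classifier(self.layer2(self.layer1(x)))


model = MainNetwork()
model_aux = nn.Sequential(nn.AdaptiveAvgPool2d((7, 7)), nn.Flatten(), nn.Linear(32 * 7 * 7, 10))
criterion = nn.CrossEntropyLoss()
# One optimizer covering both the main and the auxiliary parameters.
opt = torch.optim.Adam(list(model.parameters()) + list(model_aux.parameters()), lr=1e-4)
AUX_WEIGHT = 0.3  # hypothetical weighting of the auxiliary loss

aux_scores = {}


def forward_hook(module, inputs, output):
    # Run the auxiliary head on the live activation; the result stays in the graph.
    aux_scores["layer1"] = model_aux(output)


handle = model.layer1.register_forward_hook(forward_hook)

X = torch.randn(8, 1, 28, 28)      # random stand-in for an MNIST batch
y = torch.randint(0, 10, (8,))

opt.zero_grad()
score = model(X)                   # the hook fills aux_scores during this call
loss = criterion(score, y) + AUX_WEIGHT * criterion(aux_scores["layer1"], y)
loss.backward()                    # layer1 gets gradients from both heads
opt.step()

handle.remove()                    # detach the hook before moving it elsewhere
print(f"combined loss: {loss.item():.4f}")
```

Because the hook handle can be removed and re-registered on `layer2`, `layer3`, or `layer4` (with an auxiliary head sized for that block's channels), this keeps the "one network" option while still letting you sweep the branch position.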


Hello @MrRobot, I am facing a similar problem training a network with multiple branches. Did you find a solution?

In particular, I would like to know how to train a multi-branch network with a shared backbone, using different types of losses (MSE, CrossEntropyLoss, and L2) as well as different optimizers for different branches.
@ptrblck, if you have time, could you please take a look? It would be really helpful; I have been stuck on this issue for several weeks. Thanks!
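A minimal sketch of one common pattern for this, under several assumptions: a hypothetical shared `backbone`, a classification head trained with CrossEntropyLoss, a regression head trained with MSELoss, and "L2" read as an explicit squared-norm penalty on the regression head's parameters (passing `weight_decay` to the optimizer would be an alternative). Each parameter group gets its own optimizer; the losses are summed and `backward()` is called once, so the backbone accumulates gradients from both branches. All names, shapes, targets, and learning rates are made up for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical shared backbone and two task-specific heads.
backbone = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU())
cls_head = nn.Linear(128, 10)   # classification branch
reg_head = nn.Linear(128, 1)    # regression branch

ce_loss = nn.CrossEntropyLoss()
mse_loss = nn.MSELoss()
L2_LAMBDA = 1e-4  # assumption: "L2" as an explicit penalty on reg_head's parameters

# One optimizer per parameter group; every parameter belongs to exactly one optimizer.
opt_backbone = torch.optim.Adam(backbone.parameters(), lr=1e-4)
opt_cls = torch.optim.Adam(cls_head.parameters(), lr=1e-3)
opt_reg = torch.optim.SGD(reg_head.parameters(), lr=1e-2)
optimizers = [opt_backbone, opt_cls, opt_reg]

x = torch.randn(16, 1, 28, 28)          # random stand-in inputs
y_cls = torch.randint(0, 10, (16,))     # class labels for the classification head
y_reg = torch.randn(16, 1)              # regression targets for the regression head

for step in range(3):
    for opt in optimizers:
        opt.zero_grad()
    features = backbone(x)              # shared computation, run once per step
    loss_cls = ce_loss(cls_head(features), y_cls)
    loss_reg = mse_loss(reg_head(features), y_reg)
    loss_l2 = sum(p.pow(2).sum() for p in reg_head.parameters())
    total = loss_cls + loss_reg + L2_LAMBDA * loss_l2
    total.backward()                    # backbone receives gradients from both heads
    for opt in optimizers:
        opt.step()
    print(f"step {step}: cls={loss_cls.item():.3f} reg={loss_reg.item():.3f}")
```

If you would rather call `backward()` separately per loss, the first call needs `retain_graph=True`, because both losses share the backbone's graph.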
