Problem
I would like to build an architecture that looks like Inception network based on ResNet, where there are some auxiliary branches in the middle (the one circled in blue rectangle).
However, I would like to see the optimal position to place my auxiliary branch. Since there are 4 blocks in ResNet, there are two options.
- Build four (largely) same networks where the only difference is the auxiliary branch. The only downside is that this will make me write a lot of boilerplate code.
- Build only one network and compose ResNet and my auxiliary branch in some smart way. The good option is to use
register_forward_hook
andregister_backward_hook
.
However, there are some difficulties if I choose to go with the second option. More specifically,
- How do I correctly preserve gradient from auxiliary branch. The following code snippet could not do this and all I could get is
NoneType
. - If the problem mentioned above is somehow solved. Could the following code work correctly?
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.datasets import MNIST
global y
global grad
y = None
grad = None
class MainNetwork(nn.Module):
def __init__(self):
super(MainNetwork, self).__init__()
self.layer1 = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True))
self.layer2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True))
self.layer3 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True))
self.layer4 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True))
self.classifier = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=(7, 7)),
nn.Flatten(start_dim=1, end_dim=-1),
nn.Linear(256*7*7, 1024),
nn.ReLU(inplace=True),
nn.Linear(1024, 10))
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
output = self.classifier(x)
return output
class AuxiliaryNetwork(nn.Module):
def __init__(self, inchannel):
super(AuxiliaryNetwork, self).__init__()
self.classifier = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=(7, 7)),
nn.Flatten(start_dim=1, end_dim=-1),
nn.Linear(inchannel*7*7, 1024),
nn.ReLU(inplace=True),
nn.Linear(1024, 10))
def forward(self, x):
output = self.classifier(x)
return output
def forward_hook(self, input, output):
global y
global grad
print("\tactivating forward hook")
opt_aux.zero_grad()
output.retain_grad = True
score = model_aux(output)
loss_aux = criterion(score, y)
loss_aux.backward(retain_graph=True)
opt_aux.step()
grad = output.grad
def backward_hook(self, grad_input, grad_output):
global grad
print("\tactivating backward hook")
return (grad_input[0] + grad)
model = MainNetwork()
model_aux = AuxiliaryNetwork(inchannel=32)
opt = optim.Adam(model.parameters(), lr=1e-4)
opt_aux = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(reduction="mean")
model.layer1.register_forward_hook(forward_hook)
model.layer1.register_backward_hook(backward_hook)
dataset = MNIST(root=".", train=True, transform=T.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=32)
for X_train, y_train in dataloader:
X_train = X_train.type(torch.float32)
y_train = y_train.type(torch.int64)
y = y_train
opt.zero_grad()
score = model(X_train)
loss = criterion(input=score, target=y_train)
loss.backward(retain_graph=True)
opt.step()