I would like to route different observations within a batch to different pre-processing heads (small NNs), after which all of them go through the same backbone CNN. I thought I could achieve this using ModuleDict, like so:
class NNSandwich(nn.Module):
    """Route each sample through its own pre-processing head, then a shared backbone.

    The model owns ``num_heads`` small convolutional heads stored in an
    ``nn.ModuleDict`` under the keys ``"pre_0"``, ``"pre_1"``, ...  Every head
    is ``Conv2d(3, 32, 3) -> ReLU -> Conv2d(32, 3, 3) -> ReLU``, so it maps a
    3-channel image to a 3-channel image.  The outputs of the heads are
    re-assembled into one batch, passed through a shared ResNet-18 and finally
    through a linear layer with 1000 inputs and 2 outputs.
    """

    def __init__(self, num_heads):
        """Build ``num_heads`` heads, the shared backbone and the final classifier.

        Args:
            num_heads: Number of pre-processing heads to create; head ``i`` is
                registered under the key ``"pre_{i}"``.
        """
        super(NNSandwich, self).__init__()
        # Pre-NN: one independent small conv stack per head.
        self.pre_choices = nn.ModuleDict()
        for i in range(num_heads):
            self.pre_choices["pre_{}".format(i)] = nn.Sequential(
                nn.Conv2d(3, 32, 3),
                nn.ReLU(),
                nn.Conv2d(32, 3, 3),
                nn.ReLU()
            )
        # Shared-backbone.
        # NOTE(review): ``True`` is passed positionally; in the torchvision
        # releases this was presumably written for it is the ``pretrained``
        # flag -- confirm against the installed torchvision version.
        self.shared_backbone = torchvision.models.resnet18(True)
        # Maps the 1000-dimensional backbone output to 2 logits.
        self.fc = nn.Linear(1000, 2)

    def forward(self, x, choices):
        """Apply the selected head to each sample, then the shared layers.

        Args:
            x: Input batch; sample ``i`` is ``x[i]``.
            choices: One ``pre_choices`` key per sample, in batch order
                (e.g. ``["pre_2", "pre_0", ...]``).

        Returns:
            The output of ``fc`` for the whole batch, in the original order.

        Raises:
            ValueError: If ``choices`` does not contain exactly one key per
                sample in ``x``.
            KeyError: If a key in ``choices`` is not a registered head.
        """
        # Materialise so that any iterable (including a generator) can be
        # length-checked and then iterated.
        choices = list(choices)
        batch_size = x.shape[0]
        # Without this check a short ``choices`` would silently produce a
        # smaller output batch, because only len(choices) samples are routed.
        if len(choices) != batch_size:
            raise ValueError(
                "expected one choice per sample: got {} choices for a batch "
                "of {}".format(len(choices), batch_size)
            )
        # A ModuleDict is indexed by a single string key, so each sample is
        # sliced out (keeping the batch dimension), sent through its own head,
        # and the results are concatenated back along the batch dimension.
        routed = [
            self.pre_choices[key](x[i:i + 1])
            for i, key in enumerate(choices)
        ]
        x = torch.cat(routed, dim=0)
        x = self.shared_backbone(x)
        x = self.fc(x)
        return x
However, it seems that the key used to select a head can’t be a list (with a different entry for each observation); it has to be a single string, which would then be the same for every observation. I therefore use a list comprehension and torch.cat() to pass each observation through its respective head.
However, this method doesn’t seem to update the weights properly:
import torch
import torch.nn as nn
import torchvision
import random
# Minimal reproduction: run one SGD step with four of the five heads selected
# and report which heads received gradients / were modified.
HEADS = 5
BATCH = 4
# Use the GPU when present, but fall back to the CPU so the script still runs
# on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NNSandwich
sandwich_nn = NNSandwich(num_heads=HEADS)
sandwich_nn.to(device)

# Optimiser
optimizer = torch.optim.SGD(sandwich_nn.parameters(), lr=0.5)

# Input
batch_i = torch.ones(BATCH, 3, 224, 224, device=device)
labels_i = torch.ones(BATCH, dtype=torch.long, device=device)
choices_i = ['pre_2', 'pre_1', 'pre_0', 'pre_4']  # 'pre_3' is deliberately unused

# Snapshot every head parameter so the whole head (not just one kernel) can be
# compared after the optimiser step.
params_before = {
    name: param.detach().clone()
    for name, param in sandwich_nn.pre_choices.named_parameters()
}

# First 3x3 kernel of each head's first conv, before the update.
for i in range(HEADS):
    print(sandwich_nn.pre_choices["pre_{}".format(i)][0].weight[0][0])

# Forwards-Pass
out = sandwich_nn(batch_i, choices_i)
print(out)

# Backwards-Pass
loss = nn.functional.cross_entropy(out, labels_i)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Same kernel slice after the update.
for i in range(HEADS):
    print(sandwich_nn.pre_choices["pre_{}".format(i)][0].weight[0][0])

# The single kernel printed above is a misleading indicator: the input is all
# ones, so every output channel of a head's first conv is spatially constant,
# and a channel whose pre-activation is non-positive gets an exactly zero
# gradient through the ReLU.  A head can therefore be used (and have other
# parameters updated) while weight[0][0] stays unchanged.  Check the whole
# head instead: a head that took part in the forward pass has gradients on its
# parameters, while an unused head's gradients are still None.
for i in range(HEADS):
    key = "pre_{}".format(i)
    head_params = list(sandwich_nn.pre_choices[key].named_parameters())
    received_grad = any(param.grad is not None for _, param in head_params)
    changed = any(
        not torch.equal(params_before["{}.{}".format(key, name)], param.detach())
        for name, param in head_params
    )
    print("{}: received gradient = {}, parameters changed = {}".format(
        key, received_grad, changed))
In this case only the output of print(sandwich_nn.pre_choices.pre_1[0].weight[0][0])
differs; it seems the other heads haven’t been updated (but pre_2, pre_1, pre_0 and pre_4 should all have been).