ModuleDict to route observations to different heads

I would like to route different observations within a batch to different pre-processing heads (small NNs), after which they all go through the same CNN backbone. I thought I could achieve this with nn.ModuleDict like so:

class NNSandwich(nn.Module):
    def __init__(self, num_heads):
        super(NNSandwich, self).__init__()

        # Pre-processing heads, one per observation type
        self.pre_choices = nn.ModuleDict()
        for i in range(num_heads):
            self.pre_choices["pre_{}".format(i)] = nn.Sequential(
                nn.Conv2d(3, 32, 3),
                nn.ReLU(),
                nn.Conv2d(32, 3, 3),
                nn.ReLU()
            )

        # Shared backbone
        self.shared_backbone = torchvision.models.resnet18(pretrained=True)
        self.fc = nn.Linear(1000, 2)

    def forward(self, x, choices):
        # Route each observation through its chosen head, then re-assemble the batch
        x = torch.cat(
            [self.pre_choices[j](x[i].unsqueeze(0)) for i, j in enumerate(choices)],
            dim=0
        )
        x = self.shared_backbone(x)
        x = self.fc(x)
        return x

However, it seems that the key passed to the ModuleDict can't be a list (a different head per observation) but only a single string, i.e. the same head for every observation. So instead I use a list comprehension and torch.cat() to send each observation through its respective head.

However, this method doesn’t seem to update the weights properly:

import torch
import torch.nn as nn
import torchvision
import random

HEADS = 5
BATCH = 4

# NNSandwich
sandwich_nn = NNSandwich(num_heads=HEADS)
sandwich_nn.cuda()

# Optimiser
optimizer = torch.optim.SGD(sandwich_nn.parameters(), lr=0.5)

# Input
batch_i = torch.ones(BATCH,3,224,224).cuda()
labels_i = torch.ones(BATCH).cuda().long()
choices_i = ['pre_2', 'pre_1', 'pre_0', 'pre_4']

# Weights of each head before the update
print(sandwich_nn.pre_choices.pre_0[0].weight[0][0])
print(sandwich_nn.pre_choices.pre_1[0].weight[0][0])
print(sandwich_nn.pre_choices.pre_2[0].weight[0][0])
print(sandwich_nn.pre_choices.pre_3[0].weight[0][0])
print(sandwich_nn.pre_choices.pre_4[0].weight[0][0])

# Forwards-Pass
out = sandwich_nn(batch_i, choices_i)
print(out)
# Backwards-Pass
loss = nn.functional.cross_entropy(out, labels_i)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Weights of each head after the update; only the heads used in choices_i should change
print(sandwich_nn.pre_choices.pre_0[0].weight[0][0])
print(sandwich_nn.pre_choices.pre_1[0].weight[0][0])
print(sandwich_nn.pre_choices.pre_2[0].weight[0][0])
print(sandwich_nn.pre_choices.pre_3[0].weight[0][0])
print(sandwich_nn.pre_choices.pre_4[0].weight[0][0])

In this case only pre_1 differs after the step; the other heads don't seem to have been updated (but pre_2, pre_1, pre_0 and pre_4 should all have been, since they appear in choices_i).

Your code should work; it seems the gradients are just “killed” by the ReLUs in pre_choices, since they can create zero gradients in some iterations. I.e. if you remove both nn.ReLU() non-linearities, you'll see valid updates.
This doesn't necessarily mean that your model won't get any valid gradients in the following iterations.
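
To illustrate the point, here is a minimal sketch (the shift by 100 is only there to force every pre-activation below zero): once the ReLU outputs all zeros, the preceding conv layer receives an all-zero gradient, so an optimiser step leaves its weights unchanged.

import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, 3)
relu = nn.ReLU()

x = torch.randn(1, 3, 8, 8)
# Push every pre-activation well below zero so the ReLU saturates at 0
out = relu(conv(x) - 100.0)
out.sum().backward()

# All-zero gradient: an SGD step would not change conv.weight at all
print(conv.weight.grad.abs().sum())  # tensor(0.)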

Also, if you want to use an integer index instead of keys, you could use nn.ModuleList instead of nn.ModuleDict.
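
Roughly like this (a quick sketch; the class and argument names are only illustrative): the heads live in an nn.ModuleList and are picked by an integer index per sample.

import torch
import torch.nn as nn

class PreHeads(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        # Same routing idea as above, but heads are addressed by position
        self.pre_choices = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(3, 32, 3),
                nn.ReLU(),
                nn.Conv2d(32, 3, 3),
                nn.ReLU()
            )
            for _ in range(num_heads)
        )

    def forward(self, x, choices):
        # choices is a list of ints such as [2, 1, 0, 4], one head index per sample
        return torch.cat(
            [self.pre_choices[j](x[i].unsqueeze(0)) for i, j in enumerate(choices)],
            dim=0
        )

heads = PreHeads(num_heads=5)
out = heads(torch.ones(4, 3, 224, 224), choices=[2, 1, 0, 4])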


Thank you for the reply Patrick!

I'm still a bit confused, however, about how the gradients are calculated for the observations in the batch when routing them with the for-loop.

For example, if my nn.Module has ModuleDict(Block1A, Block1B, Block1C) -> Block2 -> Block3 -> Out

and all observations in my input batch correspond to Block1A, shouldn't the following two approaches produce the same gradients?

One-by-one:
x = torch.cat([self.moduledict['Block1A'](x[i].unsqueeze(0)) for i in range(len(x))])

Together:
x = self.moduledict['Block1A'](x)

Just to illustrate with a code-example:

import torch
torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.pre_heads = nn.ModuleDict()
        for head_name in ['a', 'b', 'c']:
            self.pre_heads[head_name] = nn.Sequential(
                nn.Conv2d(16, 16, kernel_size=1),
                nn.ReLU()
            )
        self.out = nn.Sequential(
            Flatten(),
            nn.Linear(16*4*4,1)
        )
        
    def forward(self, x, head_name):           
        x = torch.cat(
            [self.pre_heads[j](x[i].unsqueeze(0)) for i, j in enumerate(head_name)],
            dim=0
        )
        x = torch.squeeze(torch.sigmoid(self.out(x)))
        return x

mymodel = Model()
sample_input = torch.ones(4,16,4,4)
output = mymodel(sample_input, head_name=['a', 'a', 'a', 'a'])
loss = F.binary_cross_entropy(output, torch.Tensor([1,1,1,1]))
loss.backward()
mymodel.pre_heads['a'][0].weight.grad.sum()
# tensor(2.3628)

vs.

class ModelDirect(Model):
    # Only the forward changed; __init__ (and thus pre_heads) is inherited from Model above
    def forward(self, x):
        x = self.pre_heads['a'](x)
        x = torch.squeeze(torch.sigmoid(self.out(x)))
        return x

mymodel = ModelDirect()
sample_input = torch.ones(4,16,4,4)
output = mymodel(sample_input)
loss = F.binary_cross_entropy(output, torch.Tensor([1,1,1,1]))
loss.backward()
mymodel.pre_heads['a'][0].weight.grad.sum()
# tensor(-0.9457)

Both approaches should yield the same result, and I assume your two models might have used different layer parameters.
This code snippet shows that the two forward implementations create the same gradients up to floating point precision:

import torch
torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.pre_heads = nn.ModuleDict()
        for head_name in ['a', 'b', 'c']:
            self.pre_heads[head_name] = nn.Sequential(
                nn.Conv2d(16, 16, kernel_size=1),
                nn.ReLU()
            )
        self.out = nn.Sequential(
            Flatten(),
            nn.Linear(16*4*4,1)
        )
    
    def forward(self, x, head_name, use_heads=True):
        if use_heads:
            x = torch.cat(
                [self.pre_heads[j](x[i].unsqueeze(0)) for i, j in enumerate(head_name)],
                dim=0
            )
            x = torch.squeeze(torch.sigmoid(self.out(x)))
            return x
        else:
            x = self.pre_heads['a'](x)
            x = torch.squeeze(torch.sigmoid(self.out(x)))
            return x

mymodel = Model()
sample_input = torch.ones(4,16,4,4)

# Per-sample routing through head 'a' via the list comprehension
output = mymodel(sample_input, head_name=['a', 'a', 'a', 'a'])
loss = F.binary_cross_entropy(output, torch.Tensor([1,1,1,1]))
loss.backward()
grad0 = mymodel.pre_heads['a'][0].weight.grad.sum().clone()

# Whole batch through head 'a' directly
mymodel.zero_grad()
output = mymodel(sample_input, head_name=None, use_heads=False)
loss = F.binary_cross_entropy(output, torch.Tensor([1,1,1,1]))
loss.backward()
grad1 = mymodel.pre_heads['a'][0].weight.grad.sum().clone()

print((grad0 - grad1).abs())
> tensor(1.1921e-07)

Thank you for testing this out! It makes sense to me now (I guess I did have the two models initialised differently, despite setting the seed).
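
For anyone hitting the same confusion later, a small sketch of the seeding behaviour (layer shapes are arbitrary): a single torch.manual_seed call only fixes the starting state of the RNG, so two models constructed one after the other still draw different parameters; re-seeding before each construction makes them identical.

import torch
import torch.nn as nn

# One seed, two constructions: the second draw continues where the first stopped
torch.manual_seed(0)
m1 = nn.Conv2d(16, 16, kernel_size=1)
m2 = nn.Conv2d(16, 16, kernel_size=1)
print(torch.equal(m1.weight, m2.weight))  # False

# Re-seeding before each construction gives identical parameters
torch.manual_seed(0)
m3 = nn.Conv2d(16, 16, kernel_size=1)
torch.manual_seed(0)
m4 = nn.Conv2d(16, 16, kernel_size=1)
print(torch.equal(m3.weight, m4.weight))  # True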