Batchnorm in multi-head CNN

Hey there,

I intend to build a densenet121-based multi-head CNN. My goal is to use the first three denseblocks as a shared set of layers and then create multiple branches using the architecture of the fourth denseblock.

My code is below:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision


class Densenet121_v1(nn.Module):
    def __init__(self, num_classes=[1, 4, 2, 8, 3]):
        super(Densenet121_v1, self).__init__()
        original_model = torchvision.models.densenet121(pretrained=True)
        self.num_classes = num_classes
        
        self.trunk = original_model.features[:-2]

        self.branch1 = original_model.features[-2:]
        self.branch2 = original_model.features[-2:]
        self.branch3 = original_model.features[-2:]
        self.branch4 = original_model.features[-2:]
        self.branch5 = original_model.features[-2:]
        
        self.classifier1 = nn.Linear(1024, self.num_classes[0])
        self.classifier2 = nn.Linear(1024, self.num_classes[1])
        self.classifier3 = nn.Linear(1024, self.num_classes[2])
        self.classifier4 = nn.Linear(1024, self.num_classes[3])
        self.classifier5 = nn.Linear(1024, self.num_classes[4])
        
    def forward(self, x):
        
        # shared trunk
        trunk=self.trunk(x)
    
        #block1: 
        f1 = self.branch1(trunk)
        b1 = F.relu(f1, inplace=False)
        b1 = F.adaptive_avg_pool2d(b1, (1, 1)).view(f1.size(0), -1)
        y1 = self.classifier1(b1)
  
        #block2: 
        f2 = self.branch2(trunk)
        b2 = F.relu(f2, inplace=False)
        b2 = F.adaptive_avg_pool2d(b2, (1, 1)).view(f2.size(0), -1)
        y2 = self.classifier2(b2)

        #block3: 
        f3 = self.branch3(trunk)
        b3 = F.relu(f3, inplace=False)
        b3 = F.adaptive_avg_pool2d(b3, (1, 1)).view(f3.size(0), -1)
        y3 = self.classifier3(b3)

        #block4: 
        f4 = self.branch4(trunk) 
        b4 = F.relu(f4, inplace=False)
        b4 = F.adaptive_avg_pool2d(b4, (1, 1)).view(f4.size(0), -1)
        y4 = self.classifier4(b4)

        #block5: 
        f5 = self.branch5(trunk)
        b5 = F.relu(f5, inplace=False)
        b5 = F.adaptive_avg_pool2d(b5, (1, 1)).view(f5.size(0), -1)
        y5 = self.classifier5(b5)
 
        #order predictions
        output= torch.cat((y1,y2[:,0:3],torch.unsqueeze(y4[:,0],1),torch.unsqueeze(y2[:,3],1),
                            y4[:,1:3],torch.unsqueeze(y5[:,0],1),
                            torch.unsqueeze(y3[:,0],1),torch.unsqueeze(y4[:,3],1),
                            torch.unsqueeze(y3[:,1],1),y4[:,4:7],
                            torch.unsqueeze(y5[:,1],1),torch.unsqueeze(y4[:,7],1),
                            torch.unsqueeze(y5[:,2],1)),1)

        
        return output
    
model = Densenet121_v1()
out = model(torch.rand(16, 3, 224, 224))

When I inspect the computational graph in TensorBoard, I see that some tensors produced by the batch norm layers in branch1 are being fed into the remaining branches, while the same does not happen for the other branches.

How can I avoid this? The heads should be completely independent of each other.

You could try to deepcopy the branches and then reset their parameters to make them independent.
Your current code has all the branches sharing the same modules, and therefore the same weights.

Try resetting the parameters as shown here: How to re-set alll parameters in a network - #12 by Brando_Miranda
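
As a rough sketch of the idea (the weight_reset helper below is only an illustration that re-initializes conv and linear layers; the important part is that every branch gets its own deepcopy of the modules):

import copy

import torch.nn as nn
import torchvision


def weight_reset(m):
    # re-initialize conv and linear layers (batch norm layers could be added here as well)
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        m.reset_parameters()


original_model = torchvision.models.densenet121(pretrained=True)

# shared trunk: everything up to and including transition3
trunk = original_model.features[:-2]

# each branch gets its *own* copy of denseblock4 + the final batch norm layer
branch1 = copy.deepcopy(original_model.features[-2:])
branch2 = copy.deepcopy(original_model.features[-2:])

# optionally re-initialize the copies so they do not all start from the same pretrained weights
branch2.apply(weight_reset)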

Thanks for the quick reply.

I’m not sure how to do that. I tried the following:

import copy


def weight_reset(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        m.reset_parameters()

    
class Densenet121_conv(nn.Module):
    def __init__(self, num_classes = [1,4,2,8,3]):
        super(Densenet121_conv,self).__init__()
        original_model = torchvision.models.densenet121(pretrained=True)
        self.num_classes=num_classes
        
        self.trunk=original_model.features[:-2]
        
        self.branch1=original_model.features[-2:]
        self.branch2=original_model.features[-2:]
        self.branch3=original_model.features[-2:]
        self.branch4=original_model.features[-2:]
        self.branch5=original_model.features[-2:]
        
        self.classifier1=(nn.Linear(1024, self.num_classes[0]))
        self.classifier2=(nn.Linear(1024, self.num_classes[1]))
        self.classifier3=(nn.Linear(1024, self.num_classes[2]))
        self.classifier4=(nn.Linear(1024, self.num_classes[3]))
        self.classifier5=(nn.Linear(1024, self.num_classes[4]))
        
    def forward(self, x):
        
        # shared trunk
        trunk=self.trunk(x)
    
        #block1: 
        f1 = self.branch1(trunk)
        b1 = F.relu(f1, inplace=False)
        b1 = F.adaptive_avg_pool2d(b1, (1, 1)).view(f1.size(0), -1)
        y1 = self.classifier1(b1)
        
        # copy and re-set
        f1_copy=copy.deepcopy(f1)
        f1.apply(weight_reset)
  
        #block2: 
        f2 = self.branch2(trunk)
        b2 = F.relu(f2, inplace=False)
        b2 = F.adaptive_avg_pool2d(b2, (1, 1)).view(f2.size(0), -1)
        y2 = self.classifier2(b2)

        #block3: 
        f3 = self.branch3(trunk)
        b3 = F.relu(f3, inplace=False)
        b3 = F.adaptive_avg_pool2d(b3, (1, 1)).view(f3.size(0), -1)
        y3 = self.classifier3(b3)

        #block4: 
        f4 = self.branch4(trunk) 
        b4 = F.relu(f4, inplace=False)
        b4 = F.adaptive_avg_pool2d(b4, (1, 1)).view(f4.size(0), -1)
        y4 = self.classifier4(b4)

        #block5: 
        f5 = self.branch5(trunk)
        b5 = F.relu(f5, inplace=False)
        b5 = F.adaptive_avg_pool2d(b5, (1, 1)).view(f5.size(0), -1)
        y5 = self.classifier5(b5)
 
        output= torch.cat((y1,y2[:,0:3],torch.unsqueeze(y4[:,0],1),torch.unsqueeze(y2[:,3],1),
                            y4[:,1:3],torch.unsqueeze(y5[:,0],1),
                            torch.unsqueeze(y3[:,0],1),torch.unsqueeze(y4[:,3],1),
                            torch.unsqueeze(y3[:,1],1),y4[:,4:7],
                            torch.unsqueeze(y5[:,1],1),torch.unsqueeze(y4[:,7],1),
                            torch.unsqueeze(y5[:,2],1)),1)

        #restore the branch weights after the forward pass
        f1=f1_copy
        
        return output
    

but I got the error: “Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment”.

Either way, would something like this work during backprop? The idea is that the common trunk is updated by all the losses, while each sub-branch is only updated by the loss of its own output.

Further help would be really appreciated.
Thanks!

Yes, this would work. For example, the following work seems to follow this paradigm:
EnsembleNet: End-to-end optimization of multi-headed models

This might help.
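
To make the gradient flow concrete, here is a hypothetical training step (the BCEWithLogitsLoss and the random targets are just placeholders for your real labels and criteria; whether you use one loss on the concatenated output or a sum of per-head losses, the effect is the same): each branch and classifier only receives gradients from its own output columns, while the shared trunk accumulates gradients from all heads.

import torch
import torch.nn as nn

model = Densenet121_v1()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()               # placeholder loss

x = torch.rand(16, 3, 224, 224)                  # dummy batch
target = torch.randint(0, 2, (16, 18)).float()   # 18 = 1 + 4 + 2 + 8 + 3 outputs

output = model(x)                                # shape (16, 18)
loss = criterion(output, target)

optimizer.zero_grad()
loss.backward()   # each branch only gets gradients from its own columns,
                  # the shared trunk gets gradients from all of them
optimizer.step()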

I still need to test it on my dataset, but looking at the graph in TensorBoard, it seems to be working like a charm.

Thank you so much!
Sofia

Hi again,

When you say that, without the reset, all the branches share the same weights, do you mean that they all start from the same weights? And as training progresses (and backprop happens), they will end up with different weights, right?

Thank you!
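
A quick way to see what "sharing the same weights" means here is to compare the underlying module and parameter objects (a small sketch using the class from this thread):

import copy

import torch

m = Densenet121_v1()

# with the slicing approach, both branches reference the *same* module objects,
# so this prints True and the branches cannot diverge during training
print(m.branch1[0] is m.branch2[0])          # True: same denseblock4 instance

# a deepcopied branch merely *starts* from the same values and then diverges
# once its gradients differ (or immediately, if its parameters are reset)
b2 = copy.deepcopy(m.branch1)
p1 = next(m.branch1.parameters())
p2 = next(b2.parameters())
print(p1 is p2, torch.equal(p1, p2))         # False True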