Is my branch-CNN (B-CNN) implementation correct?

Apologies if reposting is against the guidelines, but I wanted to ask this one more time. I'm currently trying to implement this branch-CNN (B-CNN) model:

The authors start from a pretrained VGG16-BN model and attach three fully connected (FC) branches, which I think I've implemented here:

class BCNN(nn.Module):
    """Branch CNN built on a pretrained VGG16-BN backbone with three classifier heads.

    The VGG16-BN ``features`` stack is sliced into five conv blocks
    (``Conv1`` .. ``Conv5``). Three fully connected branches read the
    flattened activations after block 2, block 3 and block 5, and produce
    coarse-to-fine predictions (2, 7 and 25 outputs by default).

    The branch input sizes (128*56*56, 256*28*28, 512*7*7) assume
    3x224x224 input images; other spatial sizes will fail at the first
    ``Linear`` of a branch.

    NOTE: compared with the earlier version, each branch now has a ``ReLU``
    after its second hidden ``Linear``. This shifts the ``nn.Sequential``
    indices inside ``level_*_clf``, so checkpoints saved from the old layout
    will not load with ``strict=True``.

    Args:
        num_classes: output sizes of the (level-1, level-2, level-3) heads.
        pretrained: forwarded to ``models.vgg16_bn``.
    """

    def __init__(self, num_classes=(2, 7, 25), pretrained=True):
        super().__init__()
        n_level_one, n_level_two, n_level_three = num_classes

        # List the backbone layers once and slice them into conv blocks.
        # With torchvision's VGG16-BN layout, each slice ends on a max-pool.
        vgg_layers = list(models.vgg16_bn(pretrained=pretrained).features.children())

        # Submodules are registered in the same order as before, so
        # parameters() keeps its ordering.
        self.Conv1 = nn.Sequential(*vgg_layers[0:7])
        self.Conv2 = nn.Sequential(*vgg_layers[7:14])
        # Level-1 classifier after the second conv block.
        self.level_one_clf = self._make_branch(128 * 56 * 56, 256, n_level_one)
        self.Conv3 = nn.Sequential(*vgg_layers[14:24])
        # Level-2 classifier after the third conv block.
        self.level_two_clf = self._make_branch(256 * 28 * 28, 1024, n_level_two)
        self.Conv4 = nn.Sequential(*vgg_layers[24:34])
        self.Conv5 = nn.Sequential(*vgg_layers[34:44])
        # Level-3 classifier after the fifth conv block.
        self.level_three_clf = self._make_branch(512 * 7 * 7, 4096, n_level_three)

    @staticmethod
    def _make_branch(in_features, hidden, num_classes):
        """Build one FC head: two (Linear, ReLU, BatchNorm1d, Dropout) stages, then a Linear."""
        return nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.BatchNorm1d(hidden),
            nn.Dropout(0.5),
            nn.Linear(hidden, hidden),
            # This ReLU was missing. Without it the second hidden layer has
            # no nonlinearity and adds essentially no capacity.
            nn.ReLU(),
            nn.BatchNorm1d(hidden),
            nn.Dropout(0.5),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        """Return ``(level_one, level_two, level_three)`` logits for the batch ``x``."""
        x = self.Conv2(self.Conv1(x))
        # flatten(1) keeps the batch dimension and flattens the rest.
        lvl_one = self.level_one_clf(x.flatten(1))
        x = self.Conv3(x)
        lvl_two = self.level_two_clf(x.flatten(1))
        x = self.Conv5(self.Conv4(x))
        lvl_three = self.level_three_clf(x.flatten(1))
        return lvl_one, lvl_two, lvl_three

I'm less certain about the training, especially whether I'm implementing the losses correctly. If I understand it, there is a single loss function that is a weighted sum of the three per-branch losses, and the weight on each branch changes depending on the epoch. The entire training loop looks like this:

def train_model(model, dataloader, criterion, optimizer, save_path_loss, save_path_acc, num_epochs=25):
    """Train and validate a three-branch model, printing loss and per-branch accuracy.

    Each epoch runs a 'train' phase and then a 'val' phase over
    ``dataloader['train']`` and ``dataloader['val']``. Each batch is unpacked
    as ``(inputs, y1, y2, y3)``. The model's three outputs are combined into
    one loss by ``losses(...)``, which receives the current epoch, and their
    correct-prediction counts are accumulated with ``batch_accs(...)``.

    Args:
        model: module whose forward returns ``(l1_out, l2_out, l3_out)``.
        dataloader: mapping with 'train' and 'val' loaders; each loader must
            expose ``.dataset`` (used for the per-epoch averages).
        criterion: per-branch loss, passed through to ``losses``.
        optimizer: optimizer for ``model``. From epoch 54 onward it is
            replaced each epoch by ``scheduler(optimizer, epoch)``.
        save_path_loss: currently unused; kept for interface compatibility.
        save_path_acc: currently unused; kept for interface compatibility.
        num_epochs: number of epochs to run.

    Returns:
        The same ``model`` object, trained in place.

    Relies on module-level ``device``, ``scheduler``, ``losses`` and
    ``batch_accs`` being defined elsewhere.
    """
    for epoch in range(num_epochs):

        # Custom learning-rate schedule (``scheduler`` is defined elsewhere).
        if epoch > 53:
            optimizer = scheduler(optimizer, epoch)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            l1_running_corrects = 0
            l2_running_corrects = 0
            l3_running_corrects = 0

            # Iterate over data.
            for inputs, y1, y2, y3 in dataloader[phase]:

                inputs = inputs.to(device)
                l1_label = y1.to(device)
                l2_label = y2.to(device)
                l3_label = y3.to(device)

                optimizer.zero_grad()

                # The forward pass must be inside this context. Otherwise the
                # validation phase also records an autograd graph.
                with torch.set_grad_enabled(is_train):
                    l1_out, l2_out, l3_out = model(inputs)

                    loss = losses(l1_out, l1_label,
                                  l2_out, l2_label,
                                  l3_out, l3_label,
                                  criterion, epoch)

                    if is_train:
                        # A single backward() on the weighted sum gives the same
                        # gradients as separate backward() calls on each branch term.
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)

                # batch_accs takes the running totals and returns them updated
                # with this batch's correct counts.
                l1_running_corrects, \
                l2_running_corrects, \
                l3_running_corrects = batch_accs(l1_out, l1_label, l1_running_corrects,
                                                 l2_out, l2_label, l2_running_corrects,
                                                 l3_out, l3_label, l3_running_corrects)

            dataset_size = len(dataloader[phase].dataset)
            # max(..., 1) avoids ZeroDivisionError on an empty dataset. float()
            # works whether a running count is still an int or is a tensor.
            denominator = max(dataset_size, 1)
            epoch_loss = running_loss / denominator
            l1_epoch_acc = float(l1_running_corrects) / denominator
            l2_epoch_acc = float(l2_running_corrects) / denominator
            l3_epoch_acc = float(l3_running_corrects) / denominator
            print('{} Loss: {:.4f}\nLevel-1 Acc: {:.4f}\nLevel-2 Acc: {:.4f}\nLevel-3 Acc: {:.4f}'.format(phase,
                                                                                                          epoch_loss,
                                                                                                          l1_epoch_acc,
                                                                                                          l2_epoch_acc,
                                                                                                          l3_epoch_acc))

    return model

The function that calculates the losses is:

# Default loss-weight schedule for `losses`: (last_epoch_inclusive, (w1, w2, w3))
# pairs checked in order; a last_epoch of None covers every remaining epoch.
BCNN_LOSS_SCHEDULE = (
    (12, (0.98, 0.01, 0.01)),
    (22, (0.1, 0.8, 0.1)),
    (32, (0.1, 0.2, 0.7)),
    (None, (0.0, 0.0, 1.0)),
)


def losses(l1_out, l1_label, l2_out, l2_label, l3_out, l3_label, criterion, epoch,
           schedule=None):
    """Combine the three branch losses into one epoch-weighted total.

    Each branch loss is ``criterion(out, label)``. The weights come from the
    first ``schedule`` entry whose ``last_epoch`` is None or satisfies
    ``epoch <= last_epoch``. With the default schedule this gives
    0.98/0.01/0.01 for epochs <= 12, 0.1/0.8/0.1 for 13-22,
    0.1/0.2/0.7 for 23-32, and only the level-3 loss afterwards.

    Branches with weight 0 are skipped entirely. Their criterion is never
    evaluated, so a NaN/inf in an unused branch cannot leak into the total.

    Args:
        l1_out, l1_label, l2_out, l2_label, l3_out, l3_label: per-branch
            outputs and targets, passed to ``criterion``.
        criterion: callable ``criterion(output, target)`` returning a loss.
        epoch: current epoch, used to select the weights.
        schedule: optional replacement for ``BCNN_LOSS_SCHEDULE``.

    Returns:
        The weighted sum of the non-zero-weight branch losses.

    Raises:
        ValueError: if no schedule entry matches ``epoch``, or all selected
            weights are zero.
    """
    if schedule is None:
        schedule = BCNN_LOSS_SCHEDULE

    for last_epoch, weights in schedule:
        if last_epoch is None or epoch <= last_epoch:
            break
    else:
        raise ValueError('no loss weights defined for epoch {}'.format(epoch))

    branches = ((l1_out, l1_label), (l2_out, l2_label), (l3_out, l3_label))
    terms = [weight * criterion(out, label)
             for weight, (out, label) in zip(weights, branches)
             if weight != 0]
    if not terms:
        raise ValueError('all loss weights are zero for epoch {}'.format(epoch))

    return sum(terms)

And here is the accuracy function, in case I have something glaringly wrong:

def batch_accs(l1_out, l1_label, l1_running_corrects,
               l2_out, l2_label, l2_running_corrects,
               l3_out, l3_label, l3_running_corrects):
    """Add each branch's number of correct predictions in a batch to its running total.

    A prediction is the argmax over dim 1 of the branch output (presumably
    ``(batch, num_classes)`` logits — TODO confirm), compared elementwise with
    the labels.

    Args:
        l{k}_out: output of branch k.
        l{k}_label: integer class labels for branch k.
        l{k}_running_corrects: running count (int or tensor) to add to.
            It is never modified in place.

    Returns:
        Tuple ``(l1_total, l2_total, l3_total)``. Each total is the running
        count plus this batch's correct count, as a tensor.
    """
    def _n_correct(out, label):
        # Number of positions where the top-scoring class equals the label.
        return torch.sum(torch.argmax(out, dim=1) == label)

    # Out-of-place addition. The previous `+=` mutated a caller-supplied tensor
    # in place, so the caller's variable changed behind its back.
    return (l1_running_corrects + _n_correct(l1_out, l1_label),
            l2_running_corrects + _n_correct(l2_out, l2_label),
            l3_running_corrects + _n_correct(l3_out, l3_label))

I'm mainly concerned about whether I am implementing the losses correctly. For example, I read that it's okay to sum separate losses and call `.backward()` once on the total, but I'm not 100% certain I should do that rather than calling `.backward()` separately on each of the three losses.

Any feedback would be greatly appreciated!

This is a double post of an earlier thread, which has a potential answer.