Weighting the loss of two networks

I am doing image classification with two input networks whose outputs are concatenated and fed to a fully connected layer. This is my code:

import torch
import torch.nn as nn

class resnet18_meteo(nn.Module):
    def __init__(self, resnet18, meteo_NN, num_classes):
        super(resnet18_meteo, self).__init__()
        
        # Respectively a torchvision resnet-18 and a 1-hidden-layer NN
        self.resnet_CNN = resnet18
        self.meteo_net = meteo_NN
        
        # Sizes of the FC layers of both NNs
        self.len_fc_resnet = self.resnet_CNN.fc.in_features
        self.len_fc_meteo = self.meteo_net.fc_last.out_features
        
        # Remove the FC layer from the resnet. Use a local variable here;
        # assigning to self.modules would shadow nn.Module.modules().
        modules = list(self.resnet_CNN.children())[:-1]
        self.resnet18_convblocks = nn.Sequential(*modules)
        
        # Fully connected layer now takes resnet FC size + meteo FC size
        self.fc = nn.Linear(self.len_fc_resnet + self.len_fc_meteo, num_classes)
    
    def forward(self, img_x, meteo_x):
        
        # Both should be flattened at the end of their networks
        img_x = self.resnet18_convblocks(img_x)
        meteo_x = self.meteo_net(meteo_x)
        
        # Flatten convolutional features
        img_x_flattened = img_x.view(img_x.size(0), -1)
        
        # Concatenate the outputs of the CNN and the meteo NN, then classify
        out = torch.cat([img_x_flattened, meteo_x], dim=1)
        out = self.fc(out)
        return out

I use a resnet-18 to process the images. In addition, I have four non-image features that I feed into a small feedforward network. As shown above, the outputs of these two networks (sizes 512 and 32) are concatenated before the FC layer.

Now I was wondering if I can give some kind of ‘judgement weights’ to these networks. For instance, I’d like the resnet to have 70% of the influence on the decision/loss, and the non-image net to have 30%.

The only way I can think of doing this is to disconnect the networks from each other entirely, calculate a loss per data point for each network, and then weight these losses. However, I would like to know whether this kind of weighting could be incorporated into my current network structure.

Since you forward the concatenated features (the variable out) through another fully connected layer, you could simply weight the two feature blocks before calling torch.cat(). The disadvantage is that you don’t know how this fixed scaling will affect the parameters of the fully connected layer during training; a linear layer can learn weights that simply compensate for it.
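
A minimal sketch of that idea, changing only the forward() method of the model above (the 0.7/0.3 factors stand in for your desired influences):

    def forward(self, img_x, meteo_x):
        img_x = self.resnet18_convblocks(img_x)
        meteo_x = self.meteo_net(meteo_x)
        img_x_flattened = img_x.view(img_x.size(0), -1)
        
        # Scale each feature block before concatenating; note that the
        # subsequent linear layer can learn to undo this fixed scaling.
        out = torch.cat([0.7 * img_x_flattened, 0.3 * meteo_x], dim=1)
        out = self.fc(out)
        return out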

Even if you have to disconnect the networks, I would recommend calculating separate losses.
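
A sketch of the separate-loss approach, assuming one classification head per branch (img_head and meteo_head are hypothetical names, not part of your model, and the dummy tensors stand in for the flattened branch outputs and labels):

import torch
import torch.nn as nn

num_classes = 10  # assumed; use your own value

# Hypothetical per-branch classification heads, sized to match the
# branch outputs from your model (512 for the resnet, 32 for the meteo net)
img_head = nn.Linear(512, num_classes)
meteo_head = nn.Linear(32, num_classes)
criterion = nn.CrossEntropyLoss()

# Dummy batch standing in for the flattened branch outputs and the labels
img_features = torch.randn(8, 512)
meteo_features = torch.randn(8, 32)
target = torch.randint(0, num_classes, (8,))

img_logits = img_head(img_features)
meteo_logits = meteo_head(meteo_features)

# Weight the per-branch losses 70/30; gradients flow back through both branches
loss = 0.7 * criterion(img_logits, target) + 0.3 * criterion(meteo_logits, target)
loss.backward()

Unlike scaling the features, the loss weights here directly control how strongly each branch’s error signal drives the gradient updates, so the 70/30 split cannot be undone by the heads during training.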