Multi Input Network MNIST-CIFAR

I have the following meta-learning task:

We want our neural network to learn the operation of summing weights.

1) Train on MNIST and on CIFAR10 (used as a support dataset). We want the performance (accuracy) on MNIST before and after the sum to be as close as possible. CIFAR10 is used only to help learn the operation of summing models. We are interested in the MNIST weights, which we can call theta_mnist, and we hope they learn the sum operation (see the sketch after point 3).
The way to do this is the following: take 3 NNs with the same architecture but different weights, the first for MNIST, the second for CIFAR10, and the third whose weights (for each layer) are theta_mnist + theta_cifar10. The loss function for this network must be CrossEntropy(MNIST) + CrossEntropy(CIFAR10) + CrossEntropy([MNIST+CIFAR10]-MNIST) (this last term should guarantee that the network has similar performance on MNIST before and after the sum).

2) Do the same thing as in 1), replacing MNIST with SVHN and again using CIFAR10 as the support dataset.

3) Build a NN with weights theta_mnist + theta_svhn: what is its performance?
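
To make point 1) concrete, here is a minimal sketch of how I imagine the per-layer sum of weights, assuming two already-trained models with identical architectures (the names model_mnist, model_cifar and the helper sum_weights are just mine for illustration):

import copy

def sum_weights(model_mnist, model_cifar):
    # Returns a new model whose parameters are theta_mnist + theta_cifar10, layer by layer.
    model_sum = copy.deepcopy(model_mnist)   # same architecture; weights will be overwritten
    sd_mnist = model_mnist.state_dict()
    sd_cifar = model_cifar.state_dict()
    sd_sum = {name: sd_mnist[name] + sd_cifar[name] for name in sd_mnist}
    model_sum.load_state_dict(sd_sum)
    return model_sum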

Now, going step by step, I have to start from point 1.
Let's leave the loss function aside for now.
How can I do the weight-summing part?
I was thinking of the following:

  1. Create model1 and train it for some epochs on MNIST
  2. Create model2 (same architecture as model1) and train it for some epochs on CIFAR10
  3. Create modelcombo (same architecture as model1 and model2) and train it for some epochs on MNIST. The way to create modelcombo could be something like this:
import torch
import torch.nn as nn


class VGG16COMBO(nn.Module):
    
    def __init__(self, model1, model2, num_classes):
        # model1 (MNIST net) and model2 (CIFAR10 net) are passed in but not used yet;
        # the idea is to derive the combined weights theta_mnist + theta_cifar10 from them
        super(VGG16COMBO, self).__init__()

        # calculate same padding:
        # (w - k + 2*p)/s + 1 = o
        # => p = (s(o-1) - w + k)/2

        self.block_1 = nn.Sequential(
            # NB: in_channels=1 means both inputs must be single-channel,
            # so the CIFAR10 images would have to be converted to grayscale here
            nn.Conv2d(in_channels=1,
                      out_channels=64,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      # (1(32-1)- 32 + 3)/2 = 1
                      padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64,
                      out_channels=64,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2),
                         stride=(2, 2))
        )

        self.block_2 = nn.Sequential(
            nn.Conv2d(in_channels=64,
                      out_channels=128,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128,
                      out_channels=128,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2),
                         stride=(2, 2))
        )
        
        self.block_3 = nn.Sequential(
            nn.Conv2d(in_channels=128,
                      out_channels=256,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256,
                      out_channels=256,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256,
                      out_channels=256,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2),
                         stride=(2, 2))
        )

        self.block_4 = nn.Sequential(
            nn.Conv2d(in_channels=256,
                      out_channels=512,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512,
                      out_channels=512,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512,
                      out_channels=512,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2),
                         stride=(2, 2))
        ) 


        self.classifier = nn.Sequential(
            nn.Linear(2048, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.25),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.25),
            nn.Linear(4096, num_classes),
        )

    def forward(self, mnist, cifar):

        mnist = self.block_1(mnist)
        mnist = self.block_2(mnist)
        mnist = self.block_3(mnist)
        mnist = self.block_4(mnist)
        mnist = mnist.view(mnist.size(0), -1)
        mnist = self.classifier(mnist)

        cifar = self.block_1(cifar)
        cifar = self.block_2(cifar)
        cifar = self.block_3(cifar)
        cifar = self.block_4(cifar)
        cifar = cifar.view(cifar.size(0), -1)
        cifar = self.classifier(cifar)

        # concatenate the two sets of logits along the class dimension
        x = torch.cat((mnist, cifar), dim=1)
        return x
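
And here is roughly how I would call the combined model, just to check shapes (a sketch with dummy tensors; I'm assuming MNIST is padded/resized from 28x28 to 32x32 and CIFAR10 is converted to grayscale, so both fit the shared 1-channel first conv):

# quick shape check; model1/model2 are not used inside __init__ yet, so placeholders suffice
model = VGG16COMBO(None, None, num_classes=10)
mnist_batch = torch.randn(8, 1, 32, 32)   # MNIST padded/resized to 32x32
cifar_batch = torch.randn(8, 1, 32, 32)   # CIFAR10 converted to grayscale
out = model(mnist_batch, cifar_batch)
print(out.shape)                          # torch.Size([8, 20]): 10 MNIST logits + 10 CIFAR10 logits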

Do you think this could be a good idea for now? My supervisor mentioned multi-input networks for this task, and I found similar solutions on the web.
Or is this a case of weight sharing? Do you have examples of tasks similar to this one?