Combining the features of convolutional layers channel-by-channel in a multi-branch model

The convolutional model presented below has two branches, and each branch (for example) has two stages (convolutional layers).
My aim is to combine the weighted feature maps (channels) of the first convolutional layer of the second branch with the channels of the first convolutional layer of the first branch.

I want to extract the channels from the first convolutional layer in the second branch, multiply them by a weight (weight is a class in the code that produces a weighted version of its input), and stack them with the channels of the counterpart convolutional layer in the first branch. A 1x1 Conv2d then reduces the stacked feature maps back to their initial number of channels, and the first branch should continue from these combined channels, so that its next convolutional layers are computed on them. After that, I want the same kind of combination between the second convolutional layers of the two branches. (In other words, I want to combine features channel-by-channel between the branches.)
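
To make the intended combination step concrete, here is a minimal sketch of what I mean (CombineStep and channels=380 are only illustrative names and assumptions, not the actual classes in my code):

import torch
import torch.nn as nn

class CombineStep(nn.Module):
    # combines the weighted maps of one branch with the maps of its counterpart layer
    def __init__(self, channels=380):
        super(CombineStep, self).__init__()
        # 1x1 conv maps the 2 * channels stacked maps back to the initial depth
        self.reduce = nn.Conv2d(2 * channels, channels, kernel_size=1)

    def forward(self, feat_first_branch, weighted_feat_second_branch):
        stacked = torch.cat((feat_first_branch, weighted_feat_second_branch), dim=1)  # stack along the channel dimension
        return self.reduce(stacked)  # back to `channels` maps for the next layer of the first branch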

Please find the main_class (the whole model, which consists of the two branches), the first_branch, and the second_branch below:

class main_class(nn.Module):
    def __init__(self, pretrained=False):
        super(main_class, self).__init__()
        self.input = input_data()  # input_data is a class that provides the input data for each branch
        
        self.conv_t2 = BasicConv3d(...........)
        self.second_branch=second_branch(512, out_sigmoid=True)
        
        self.conv_t1 = BasicConv3d(..............)
        self.first_branch=first_branch(512, out_sigmoid=True)
  
        self.last = nn.Conv2d(4, 1, kernel_size=1, stride=1)
        
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, par=False):
        x1, x2 = self.input(x)
        
        #second branch
        y2 = self.conv_t2(x2)
        out2 = self.second_branch(y2)
        
        #first branch
        y1 = self.conv_t1(x1)
        out1 = self.first_branch(y1)
        
        x = torch.cat((out2, out1), 1)
        x = self.last(x)
        out = self.sigmoid(x)
        
        if par:
            return  out1, out2, out
        
        return out

The first_branch:

class first_branch(nn.Module):
    def __init__(self, in_channel=512, out_channel=[380, 200], out_sigmoid=False):
        super(first_branch, self).__init__()
        
        self.out_sigmoid = out_sigmoid
        self.sigmoid = nn.Sigmoid()  # needed because forward() applies self.sigmoid when out_sigmoid is True

        self.deconvlayer1_2 = self._make_deconv(in_channel, out_channel[0], num_conv=3)
        self.upsample1_2 = Upsample(scale_factor=2, mode='bilinear')
        self.combined1_2 = nn.Conv2d(720, 380, kernel_size=1, stride=1, padding=0)

        self.deconvlayer1_1 = self._make_deconv(out_channel[0], out_channel[1], num_conv=3)
        self.upsample1_1 = Upsample(scale_factor=2, mode='bilinear')
        self.combined1_1 = nn.Conv2d(400, 200, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.deconvlayer1_2(x)
        x = self.upsample1_2(x)
        x = self.deconvlayer1_1(x)
        x = self.upsample1_1(x)
        
        if self.out_sigmoid:
            x = self.sigmoid(x)
        
        return x

The second_branch:

class second_branch(nn.Module):
    def __init__(self, in_channel=512, out_channel=[380,200], out_sigmoid=False):
        super(second_branch, self).__init__()
        
        self.out_sigmoid = out_sigmoid
        self.sigmoid = nn.Sigmoid()  # needed because forward() applies self.sigmoid when out_sigmoid is True
        
        self.weight = weight()  # weight is a class that returns a weighted version of its input

        self.deconvlayer2_2 = self._make_deconv(in_channel, out_channel[0], num_conv=3)
        self.upsample2_2 = Upsample(scale_factor=2, mode='bilinear')
        self.deconvlayer2_1 = self._make_deconv(out_channel[0], out_channel[1], num_conv=3)
        self.upsample2_1 = Upsample(scale_factor=2, mode='bilinear')
        
    
    def forward(self, x):
        x = self.deconvlayer2_2(x)
        x = self.upsample2_2(x)
        weighted2_2 = self.weight(x)

        x = self.deconvlayer2_1(x)
        x = self.upsample2_1(x)
        weighted2_1 = self.weight(x)

        if self.out_sigmoid:
            x = self.sigmoid(x)
        
        return x, weighted2_1, weighted2_2

To implement the idea described above, I modified the main_class as follows (instead of calling the first_branch class in the forward function of main_class, I wrote the lines of first_branch.forward() directly in the forward function of main_class):

class main_class(nn.Module):
    def __init__(self, pretrained=False):
        super(main_class, self).__init__()
        self.input = input_data()  # input_data is a class that provides the input data for each branch
        
        self.conv_t2 = BasicConv3d(....................)
        self.second_branch=second_branch(512, out_sigmoid=True)
        
        self.conv_t1 = BasicConv3d(............)
        self.first_branch=first_branch(512, out_sigmoid=True)
  
        self.last = nn.Conv2d(4, 1, kernel_size=1, stride=1)
        
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, par=False):
        x1, x2 = self.input(x)
        
        #second branch
        y2 = self.conv_t2(x2)

        out2, weighted2_1, weighted2_2 = self.second_branch(y2)

        
        #first branch
        y1 = self.conv_t1(x1)

        # instead of calling first_branch directly, the lines of first_branch.forward() are written out here:
        x = self.first_branch.deconvlayer1_2(y1)
        x = self.first_branch.upsample1_2(x)
        stacking_2 = torch.cat((x, weighted2_2), dim=1)  # stack the channels of the two branches
        x = self.first_branch.combined1_2(stacking_2)    # 1x1 conv back to the initial depth

        x = self.first_branch.deconvlayer1_1(x)
        x = self.first_branch.upsample1_1(x)
        stacking_1 = torch.cat((x, weighted2_1), dim=1)
        x = self.first_branch.combined1_1(stacking_1)

        out1 = self.sigmoid(x)

        x = torch.cat((out2, out1), 1)
        x = self.last(x)
        out = self.sigmoid(x)
        
        if par:
            return  out1, out2, out
        
        return out

I ran into the following error:

TypeError: Cannot create a consistent method resolution order (MRO) for bases Module, second_branch
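
(For context, as far as I can tell this TypeError is normally raised when nn.Module is listed before one of its own subclasses in a class's bases. The following snippet is purely hypothetical and only reproduces the same message; my real class definition may of course differ from this.)

import torch.nn as nn

class second_branch(nn.Module):
    pass

try:
    # bases cannot be linearized: nn.Module is listed before its own subclass
    class main_class(nn.Module, second_branch):
        pass
except TypeError as e:
    print(e)  # Cannot create a consistent method resolution order (MRO) for bases Module, second_branch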

How can I fix this problem, and how can I make the code support this kind of interaction for new branches that may be added to the model later? (For example, if I have three branches, how can I have this kind of data combination between the third branch and the second one, and then between the output of that combination and the first branch?) A sketch of what I mean follows below.
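
To make the second part of the question more concrete, the kind of cascading combination I have in mind looks roughly like the sketch below. It is purely illustrative: it assumes every branch returns a list of per-stage feature maps with matching channel counts and that reducers holds one 1x1 convolution per stage; it also glosses over feeding the combined maps back into the next stage of the receiving branch, which is what I ultimately want.

import torch

def cascade_combine(branch_feature_lists, reducers):
    # branch_feature_lists: one list of per-stage feature maps per branch,
    # ordered from the first branch to the last one added
    combined = branch_feature_lists[-1]
    for feats in reversed(branch_feature_lists[:-1]):
        # combine the running result with the next branch, stage by stage
        combined = [reduce(torch.cat((c, f), dim=1))
                    for reduce, c, f in zip(reducers, combined, feats)]
    return combined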

Did you try:
super(my_main_class, self).__init__()

Not saying this will help, but…

Thanks for your reply. In the question I named the modified class my_main_class to be clearer, but I have applied your idea in the source code. In other words, my_main_class is the main_class with the idea applied to it; the name of the modified class was changed only to avoid ambiguity.