Multi-inheritance problem

Hi,

I need to implement a class A, and two children class B and C, and a grand-child D, The code is like this:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as modelzoo
import torch.distributed as dist



class A(nn.Module):

    def __init__(self, ratio=4, *args, **kwargs):
        super().__init__()

        self.conv_base = nn.Conv2d(3, 3 * ratio, 3, 1, 1)


class B(A):

    def __init__(self, b_args, **kwargs):
        super().__init__(ratio=4)

        self.conv1 = nn.Conv2d(4, 3, 1, 1, 0)


class C(A):

    def __init__(self, c_args, **kwargs):
        super().__init__(ratio=4)

        self.conv2 = nn.Conv2d(4, 3, 1, 1, 0)


class D(B, C):

    def __init__(self, b_args, c_args):
        super().__init__(b_args=b_args, c_args=c_args)

        self.conv3 = nn.Conv2d(4, 3, 1, 1, 0)


b_args = dict(a=1)
c_args = dict(b=2)
model = D(b_args, c_args)

And I observed this error message:

    super().__init__(ratio=4)
TypeError: C.__init__() missing 1 required positional argument: 'c_args'

Would you tell me how could I make this work?

The issue arises because D inherits from both B and C, and when super() is used in D.__init__, Python’s method resolution order (MRO) calls B.__init__, which subsequently calls A.__init__. However, C.__init__ is not automatically called, so its required c_args is missing.

To fix this, explicitly call the __init__ methods of both parent classes (B and C) within D.__init__.
import torch
import torch.nn as nn

class A(nn.Module):
def init(self, ratio=4, *args, **kwargs):
super().init()
self.conv_base = nn.Conv2d(3, 3 * ratio, 3, 1, 1)

class B(A):
def init(self, b_args, **kwargs):
super().init(ratio=4, **kwargs)
self.conv1 = nn.Conv2d(4, 3, 1, 1, 0)

class C(A):
def init(self, c_args, **kwargs):
super().init(ratio=4, **kwargs)
self.conv2 = nn.Conv2d(4, 3, 1, 1, 0)

class D(B, C):
def init(self, b_args, c_args):
B.init(self, b_args=b_args)
C.init(self, c_args=c_args)
self.conv3 = nn.Conv2d(4, 3, 1, 1, 0)

Example usage

b_args = dict(a=1)
c_args = dict(b=2)
model = D(b_args, c_args)
Explicit Parent Initialization: In D.__init__, explicitly call B.__init__ and C.__init__ with their respective arguments.
Avoid Super with Multiple Inheritance: Using super() in complex multiple inheritance setups can lead to skipped initializations due to MRO. Explicit calls ensure all parent classes are initialized properly.

This structure ensures all required arguments (b_args and c_args) are passed and handled correctly.

I still have error message:

Traceback (most recent call last):
  File "tmp.py", line 101, in <module>
    model = D(b_args, c_args)
  File "tmp.py", line 93, in __init__
    B.__init__(self, b_args=b_args)
  File "
TypeError: C.__init__() missing 1 required positional argument: 'c_args'

Though I have changed my code like this:

class A(nn.Module):

    def __init__(self, ratio=4, *args, **kwargs):
        super().__init__()

        self.conv_base = nn.Conv2d(3, 3 * ratio, 3, 1, 1)


class B(A):

    def __init__(self, b_args, **kwargs):
        super().__init__(ratio=4)

        self.conv1 = nn.Conv2d(4, 3, 1, 1, 0)


class C(A):

    def __init__(self, c_args, **kwargs):
        super().__init__(ratio=4)

        self.conv2 = nn.Conv2d(4, 3, 1, 1, 0)


class D(B, C):

    def __init__(self, b_args, c_args):
        #  super().__init__(b_args=b_args, c_args=c_args)
        B.__init__(self, b_args=b_args)
        C.__init__(self, c_args=c_args)

        self.conv3 = nn.Conv2d(4, 3, 1, 1, 0)


b_args = dict(a=1)
c_args = dict(b=2)
model = D(b_args, c_args)