I am trying to create a module class that inherits from two classes as follows:
```python
from torch import nn


class Module1(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.l1_loss = nn.L1Loss()


class Module2(nn.Module):
    def __init__(self, damy):
        nn.Module.__init__(self)
        self.damy = damy


class Module12(Module1, Module2):
    def __init__(self, damy):
        Module1.__init__(self)
        if hasattr(self, 'l1_loss'):
            print('l1_loss at position 1')
        Module2.__init__(self, damy)
        if hasattr(self, 'l1_loss'):
            print('l1_loss at position 2')
```
When I instantiate `Module12`, “l1_loss at position 1” is printed but “l1_loss at position 2” is not. It seems that calling `Module2.__init__(self, damy)` deletes the member variable `self.l1_loss`. How can I create a module with multiple inheritance?
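A likely explanation: `nn.Module` does not keep submodules in the ordinary instance `__dict__`. Assigning an `nn.Module` (such as `nn.L1Loss()`) to an attribute registers it in an internal dictionary (`self._modules`), and `nn.Module.__init__` creates that dictionary from scratch. Calling `nn.Module.__init__` a second time, as `Module2.__init__` does, therefore discards `l1_loss`, while plain attributes such as `damy` are unaffected. The sketch below mimics this with a hypothetical, much-simplified `TinyModule` class (not PyTorch's real implementation) so it runs without PyTorch:

```python
class TinyModule:
    """Hypothetical stand-in that mimics how nn.Module stores submodules."""

    def __init__(self):
        # Like nn.Module.__init__, start with a fresh, empty submodule registry.
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        # Submodules go into the registry; anything else is a normal attribute.
        if isinstance(value, TinyModule):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only called when normal lookup fails: fall back to the registry.
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        raise AttributeError(name)


class Loss(TinyModule):
    pass


class Module1(TinyModule):
    def __init__(self):
        TinyModule.__init__(self)
        self.l1_loss = Loss()  # stored in _modules, not in __dict__


class Module2(TinyModule):
    def __init__(self, damy):
        TinyModule.__init__(self)  # replaces _modules with a new empty dict
        self.damy = damy


class Module12(Module1, Module2):
    def __init__(self, damy):
        Module1.__init__(self)
        self.had_loss_after_module1 = hasattr(self, "l1_loss")
        Module2.__init__(self, damy)
        self.has_loss_after_module2 = hasattr(self, "l1_loss")


m = Module12("damy")
print(m.had_loss_after_module1, m.has_loss_after_module2)  # True False
```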
A very similar problem is discussed in this issue, where it is solved with the built-in function `super()`. However, in my code `Module1.__init__` and `Module2.__init__` take different numbers of arguments, so I don't think `super()` can be used.
You should not call `Module1.__init__()` or even `nn.Module.__init__()` directly in your own `__init__`; call `super().__init__()` instead.
This ensures that multiple inheritance is handled properly.
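As a rough illustration of why this helps: `super()` follows the method resolution order (MRO) of the instance's actual class, so a chain of cooperative `super().__init__()` calls visits every class exactly once and reaches the shared base only once. A pure-Python sketch, where `Base` is a hypothetical stand-in for `nn.Module`, the `calls` list is only there to record the order, and all constructors take no arguments (the `damy` argument is left out here):

```python
calls = []  # records the order in which the __init__ methods run


class Base:
    """Hypothetical stand-in for nn.Module."""

    def __init__(self):
        calls.append("Base")
        super().__init__()


class Module1(Base):
    def __init__(self):
        calls.append("Module1")
        super().__init__()


class Module2(Base):
    def __init__(self):
        calls.append("Module2")
        super().__init__()


class Module12(Module1, Module2):
    def __init__(self):
        calls.append("Module12")
        super().__init__()  # walks Module1 -> Module2 -> Base


Module12()
print(calls)  # ['Module12', 'Module1', 'Module2', 'Base']
```

The shared base appears once in the log, so its initialization cannot wipe out state set up by a sibling class.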
Following your advice, I changed `Module1` and `Module2` to initialize their parent class with `super().__init__()`, but now an error occurs. Here are the code and the result:
```python
from torch import nn

class Module1(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1_loss = nn.L1Loss()

class Module2(nn.Module):
    def __init__(self, damy):
        super().__init__()
        self.damy = damy

class Module12(Module1, Module2):
    def __init__(self, damy):
        Module1.__init__(self)
        print(f"l1_loss at position 1:{hasattr(self, 'l1_loss')}")

        Module2.__init__(self, damy)
        print(f"l1_loss at position 2:{hasattr(self, 'l1_loss')}")
        print(f"damy at position 2:{hasattr(self, 'damy')}")

module12 = Module12("damy")
```
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-24-e849cf8c32a4> in <module>()
     21         print(f"damy at position 2:{hasattr(self, 'damy')}")
     22
---> 23 module12 = Module12("damy")

<ipython-input-24-e849cf8c32a4> in __init__(self, damy)
     14 class Module12(Module1, Module2):
     15     def __init__(self, damy):
---> 16         Module1.__init__(self)
     17         print(f"l1_loss at position 1:{hasattr(self, 'l1_loss')}")
     18

<ipython-input-24-e849cf8c32a4> in __init__(self)
      3 class Module1(nn.Module):
      4     def __init__(self):
----> 5         super().__init__()
      6         self.l1_loss = nn.L1Loss()
      7

TypeError: __init__() missing 1 required positional argument: 'damy'
```
I guess this is because, when `Module12` calls `Module1.__init__()`, the `super()` inside `Module1` resolves to the other parent class of `Module12`, namely `Module2`.
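That guess matches how `super()` works: it looks up the next class in the MRO of `type(self)`, not of the class in which it is written. For a `Module12` instance the MRO is `Module12, Module1, Module2, nn.Module, object`, so the `super().__init__()` inside `Module1` ends up calling `Module2.__init__()` without `damy`. The sketch below checks this with a hypothetical, argument-free `Base` class standing in for `nn.Module`:

```python
class Base:
    """Hypothetical stand-in for nn.Module (takes no arguments)."""

    def __init__(self):
        super().__init__()


class Module1(Base):
    def __init__(self):
        # Resolves to the next class in type(self).__mro__, not necessarily Base.
        super().__init__()


class Module2(Base):
    def __init__(self, damy):
        super().__init__()
        self.damy = damy


class Module12(Module1, Module2):
    def __init__(self, damy):
        Module1.__init__(self)
        Module2.__init__(self, damy)


mro_names = [cls.__name__ for cls in Module12.__mro__]
print(mro_names)  # ['Module12', 'Module1', 'Module2', 'Base', 'object']

error_message = None
try:
    Module12("damy")
except TypeError as exc:
    # Raised because Module1's super().__init__() called Module2.__init__()
    # without the required 'damy' argument.
    error_message = str(exc)
print(error_message)
```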
In `Module12`, I can't use `super()` because the parent classes take different numbers of arguments. If there is a way to solve this problem, please let me know.
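`super()` can still be used when the argument lists differ. A common pattern is cooperative multiple inheritance: each `__init__` takes the arguments it needs by keyword and forwards everything else with `**kwargs`, so one `super().__init__()` call in `Module12` walks the whole MRO and `nn.Module.__init__` runs exactly once. A sketch under that assumption, reusing the class names from the question:

```python
from torch import nn


class Module1(nn.Module):
    def __init__(self, **kwargs):
        # Forward arguments meant for classes later in the MRO (here: damy).
        super().__init__(**kwargs)
        self.l1_loss = nn.L1Loss()


class Module2(nn.Module):
    def __init__(self, damy, **kwargs):
        # Consume damy; anything left over continues up the MRO.
        super().__init__(**kwargs)
        self.damy = damy


class Module12(Module1, Module2):
    def __init__(self, damy):
        # One call walks Module1 -> Module2 -> nn.Module, so
        # nn.Module.__init__ runs once, before any attribute is set.
        super().__init__(damy=damy)


module12 = Module12("damy")
print(hasattr(module12, "l1_loss"), module12.damy)  # True damy
```

Because `Module1` only accepts keyword arguments, `damy` has to be passed by keyword (`super().__init__(damy=damy)`); passing it positionally would raise a `TypeError`. `Module1()` and `Module2("damy")` still work on their own.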