I am trying to create a module class that inherits from two classes as follows:
```python
from torch import nn

class Module1(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.l1_loss = nn.L1Loss()

class Module2(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)

class Module12(Module1, Module2):
    def __init__(self):
        Module1.__init__(self)
        if hasattr(self, 'l1_loss'):
            print('l1_loss at position 1')
        Module2.__init__(self)
        if hasattr(self, 'l1_loss'):
            print('l1_loss at position 2')
```
When I run the code, “l1_loss at position 1” is printed but not “l1_loss at position 2”. It seems that the member variable `self.l1_loss` is deleted when `Module2.__init__(self)` is called. Why? How can I create a module with multiple inheritance?
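The effect can be reproduced without PyTorch. The sketch below uses a hypothetical `ToyModule` base class that imitates a single aspect of `nn.Module` and is not PyTorch's implementation: submodules are stored in a `_modules` dict rather than in the instance `__dict__`, and `__init__` replaces that dict with an empty one. `L1Loss` here is a placeholder, not `nn.L1Loss`.

```python
class ToyModule:
    """Hypothetical stand-in for nn.Module; NOT PyTorch's real implementation."""

    def __init__(self):
        # Like nn.Module.__init__, start from a fresh, empty submodule registry.
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        # Submodules go into the registry; anything else into the instance __dict__.
        if isinstance(value, ToyModule):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only called when normal lookup fails: fall back to the registry.
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        raise AttributeError(name)


class L1Loss(ToyModule):  # hypothetical placeholder for nn.L1Loss
    pass


class Module1(ToyModule):
    def __init__(self):
        ToyModule.__init__(self)
        self.l1_loss = L1Loss()  # stored in self._modules


class Module2(ToyModule):
    def __init__(self):
        ToyModule.__init__(self)  # replaces self._modules with an empty dict


class Module12(Module1, Module2):
    def __init__(self):
        Module1.__init__(self)
        self.after_module1 = hasattr(self, "l1_loss")
        Module2.__init__(self)
        self.after_module2 = hasattr(self, "l1_loss")


m = Module12()
print(m.after_module1, m.after_module2)  # True False
```

The attribute is not deleted as such: the whole registry that held it is replaced by the second base-class initializer.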
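`nn.Module.__init__` assigns fresh, empty `_parameters`, `_buffers` and `_modules` dictionaries to the instance, and `self.l1_loss = nn.L1Loss()` is stored in `self._modules`. Calling `nn.Module.__init__` a second time (which is what `Module2.__init__` does) therefore throws `l1_loss` away. Let `super()` call the initializers instead:

```python
class Module12(Module1, Module2):
    def __init__(self):
        # Follow the MRO instead of naming each base class explicitly
        super(Module12, self).__init__()
        if hasattr(self, 'l1_loss'):
            print('l1_loss at position 1')
        if hasattr(self, 'l1_loss'):
            print('l1_loss at position 2')

m = Module12()
```

With the classes from the question, this only reaches `Module1.__init__`, because `Module1` calls `nn.Module.__init__(self)` directly rather than `super().__init__()`, so `Module2.__init__` never runs. If both `Module1` and `Module2` call `super().__init__()` instead, the call chain follows the MRO `Module12 → Module1 → Module2 → nn.Module`, and each `__init__`, including `nn.Module.__init__`, runs exactly once.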
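The cooperative version can be sketched in plain Python, assuming a hypothetical `ToyModule` stand-in for `nn.Module` (not PyTorch's code) whose `__init__` recreates the `_modules` registry and also counts how often it runs on an instance. Every class calls `super().__init__()`, so the single call in `Module12` walks the whole MRO.

```python
class ToyModule:
    """Hypothetical stand-in for nn.Module; NOT PyTorch's real implementation."""

    def __init__(self):
        # Count base-initializer runs, then start from a fresh submodule registry.
        object.__setattr__(self, "init_calls", self.__dict__.get("init_calls", 0) + 1)
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        # Submodules go into the registry; anything else into the instance __dict__.
        if isinstance(value, ToyModule):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only called when normal lookup fails: fall back to the registry.
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        raise AttributeError(name)


class L1Loss(ToyModule):  # hypothetical placeholder for nn.L1Loss
    pass


class Module1(ToyModule):
    def __init__(self):
        super().__init__()  # next class in the MRO, which is Module2 for a Module12
        self.l1_loss = L1Loss()


class Module2(ToyModule):
    def __init__(self):
        super().__init__()  # next class in the MRO: ToyModule


class Module12(Module1, Module2):
    def __init__(self):
        super().__init__()  # Module1 -> Module2 -> ToyModule, once each


m = Module12()
print([cls.__name__ for cls in Module12.__mro__])
# ['Module12', 'Module1', 'Module2', 'ToyModule', 'object']
print(m.init_calls, hasattr(m, "l1_loss"))  # 1 True
```

`Module2.__init__` is reached from `Module1.__init__` even though `Module1` does not inherit from `Module2`: `super()` delegates to the next class in the MRO of the instance's type, not to the literal parent.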
If the example is a bit more complicated and you actually have to call the `__init__` methods separately, you can guard the call to `nn.Module.__init__` with a flag, something like:
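```python
from torch import nn

class Module1(nn.Module):
    def __init__(self):
        # Run nn.Module.__init__ only if no other base class has done so yet.
        # Single underscore: a "__" name would be mangled and never found by hasattr.
        if not hasattr(self, "_module_initialized"):
            nn.Module.__init__(self)
            self._module_initialized = True
        self.l1_loss = nn.L1Loss()

class Module2(nn.Module):
    def __init__(self):
        if not hasattr(self, "_module_initialized"):
            nn.Module.__init__(self)
            self._module_initialized = True
```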
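A quick way to check that the flag guard works, and why the flag should have a single leading underscore, is the plain-Python sketch below. `ToyModule` is a hypothetical stand-in that imitates only one aspect of `nn.Module` (its `__init__` recreates the `_modules` registry), and `Flagged` is a hypothetical class; neither is PyTorch code. Inside a class body, `self.__module_initialized` is name-mangled to `_Module1__module_initialized` (or `_Module2__module_initialized`), but the string passed to `hasattr` is not mangled, so `hasattr(self, "__module_initialized")` would never find the flag and `nn.Module.__init__` would still run twice.

```python
class ToyModule:
    """Hypothetical stand-in for nn.Module; NOT PyTorch's real implementation."""

    def __init__(self):
        # Like nn.Module.__init__, start from a fresh, empty submodule registry.
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        if isinstance(value, ToyModule):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        raise AttributeError(name)


class L1Loss(ToyModule):  # hypothetical placeholder for nn.L1Loss
    pass


class Module1(ToyModule):
    def __init__(self):
        if not hasattr(self, "_module_initialized"):
            ToyModule.__init__(self)
            self._module_initialized = True
        self.l1_loss = L1Loss()


class Module2(ToyModule):
    def __init__(self):
        if not hasattr(self, "_module_initialized"):
            ToyModule.__init__(self)
            self._module_initialized = True


class Module12(Module1, Module2):
    def __init__(self):
        Module1.__init__(self)
        Module2.__init__(self)  # the guard skips the second registry reset


m = Module12()
print(hasattr(m, "l1_loss"))  # True


class Flagged:
    def __init__(self):
        self.__flag = True  # actually stored as _Flagged__flag


f = Flagged()
print(hasattr(f, "__flag"), hasattr(f, "_Flagged__flag"))  # False True
```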
That said, your use case looks a bit strange to me. I had a similar problem, but it came from a diamond inheritance structure. In your case, I wonder if it might be better to let `Module12` have two members of types `Module1` and `Module2` instead of inheriting from both.
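A minimal sketch of that composition approach, using a hypothetical plain-Python `ToyModule` in place of `nn.Module` (not PyTorch's implementation); the attribute names `module1` and `module2` are illustrative. Each child runs its own initializer on its own object, so no registry is ever reset. With the real `nn.Module`, assigning `Module1()` and `Module2()` as attributes registers them as submodules, so their parameters are included in `Module12.parameters()`.

```python
class ToyModule:
    """Hypothetical stand-in for nn.Module; NOT PyTorch's real implementation."""

    def __init__(self):
        # Like nn.Module.__init__, start from a fresh, empty submodule registry.
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        # Submodules go into the registry; anything else into the instance __dict__.
        if isinstance(value, ToyModule):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        raise AttributeError(name)


class L1Loss(ToyModule):  # hypothetical placeholder for nn.L1Loss
    pass


class Module1(ToyModule):
    def __init__(self):
        ToyModule.__init__(self)
        self.l1_loss = L1Loss()


class Module2(ToyModule):
    def __init__(self):
        ToyModule.__init__(self)


class Module12(ToyModule):
    def __init__(self):
        ToyModule.__init__(self)
        # Composition: each part keeps its own registry, nothing is initialized twice.
        self.module1 = Module1()
        self.module2 = Module2()


m = Module12()
print(list(m._modules))  # ['module1', 'module2']
print(hasattr(m.module1, "l1_loss"))  # True
```

The loss is then reached as `m.module1.l1_loss`, which also makes it explicit which part of the model owns it.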