Strange torch.nn.Module Behaviour with Multiple Inheritance

I am trying to create a module class that inherits from two classes as follows:

from torch import nn

class Module1(nn.Module):
	def __init__(self):
		nn.Module.__init__(self)

		self.l1_loss = nn.L1Loss()


class Module2(nn.Module):
	def __init__(self):
		nn.Module.__init__(self)


class Module12(Module1, Module2):
	def __init__(self):
		Module1.__init__(self)

		if hasattr(self, 'l1_loss'):
			print('l1_loss at position 1')

		Module2.__init__(self)

		if hasattr(self, 'l1_loss'):
			print('l1_loss at position 2')

When I run the code, “l1_loss at position 1” is printed but not “l1_loss at position 2”. It seems that when Module2.__init__(self) is called, the member variable self.l1_loss is deleted. Why? How can I create a module with multiple inheritance?
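As for the "why": nn.Module.__init__ creates the internal _modules dictionary in which submodules such as self.l1_loss are registered, and looking up an attribute like l1_loss falls back to that dictionary. Calling nn.Module.__init__ a second time therefore replaces it with an empty one, and the submodule is no longer reachable. The sketch below reproduces only this aspect in plain Python. ToyModule and ToyLoss are hypothetical, heavily simplified stand-ins for nn.Module and nn.L1Loss, not PyTorch's actual implementation.

```python
class ToyModule:
    """Hypothetical, heavily simplified stand-in for torch.nn.Module."""

    def __init__(self):
        # Like nn.Module.__init__, (re)create the submodule registry.
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        # Submodules go into the registry instead of the instance __dict__.
        if isinstance(value, ToyModule):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only called when normal lookup fails: fall back to the registry.
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        raise AttributeError(name)


class ToyLoss(ToyModule):
    pass


class Module1(ToyModule):
    def __init__(self):
        ToyModule.__init__(self)
        self.l1_loss = ToyLoss()


class Module2(ToyModule):
    def __init__(self):
        ToyModule.__init__(self)


class Module12(Module1, Module2):
    def __init__(self):
        Module1.__init__(self)
        print("after Module1:", hasattr(self, "l1_loss"))  # True
        Module2.__init__(self)  # replaces self._modules with an empty dict
        print("after Module2:", hasattr(self, "l1_loss"))  # False


Module12()
```

The second __init__ call does not delete the attribute directly; it throws away the registry that the attribute lookup relies on.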

I think you could use:

class Module12(Module1, Module2):
	def __init__(self):
		super(Module12, self).__init__()
		if hasattr(self, 'l1_loss'):
			print('l1_loss at position 1')

		if hasattr(self, 'l1_loss'):
			print('l1_loss at position 2')

m = Module12()

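which will print both statements. Note that super(Module12, self).__init__() only reaches Module1.__init__ here: because Module1 calls nn.Module.__init__(self) directly instead of super().__init__(), Module2.__init__ is never run, which is harmless only because it does nothing else.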
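If Module1 and Module2 themselves call super().__init__() instead of nn.Module.__init__(self), every __init__ in the hierarchy runs exactly once, in method resolution order (Module12, Module1, Module2, nn.Module), and nn.Module.__init__ finishes before any submodule is assigned. A minimal sketch, assuming you can change Module1 and Module2; the mse_loss submodule in Module2 is hypothetical and only there to show that both parents' attributes survive:

```python
from torch import nn


class Module1(nn.Module):
    def __init__(self):
        # Delegate along the MRO instead of calling nn.Module.__init__
        # directly, so nn.Module.__init__ runs exactly once.
        super().__init__()
        self.l1_loss = nn.L1Loss()


class Module2(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical extra submodule, added only for illustration.
        self.mse_loss = nn.MSELoss()


class Module12(Module1, Module2):
    def __init__(self):
        # MRO: Module12 -> Module1 -> Module2 -> nn.Module -> object
        super().__init__()


m = Module12()
print([cls.__name__ for cls in Module12.__mro__])
print(hasattr(m, "l1_loss"), hasattr(m, "mse_loss"))  # True True
```

The key design point is that each class assigns its submodules only after its own super().__init__() call returns, so the single nn.Module.__init__ has already run.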


If the example is a bit more complicated and you actually have to call the __init__ methods separately, you can guard the call to nn.Module.__init__ so that it only runs once, something like:

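class Module1(nn.Module):
	def __init__(self):
		# Run nn.Module.__init__ only once, because calling it again resets
		# the registered submodules. The flag uses a single underscore:
		# "self.__flag" would be name-mangled to "_Module1__flag", so a
		# hasattr(self, "__flag") check would never find it.
		if not hasattr(self, "_module_initialized"):
			nn.Module.__init__(self)
			self._module_initialized = True

		self.l1_loss = nn.L1Loss()


class Module2(nn.Module):
	def __init__(self):
		if not hasattr(self, "_module_initialized"):
			nn.Module.__init__(self)
			self._module_initialized = True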

Though I have to say that your use case looks a bit strange to me. I had a similar problem, but it came from a diamond inheritance structure. In your case, I wonder if it might be better to let Module12 have two members of types Module1 and Module2.
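A minimal sketch of that composition approach: Module12 inherits only from nn.Module and owns one instance of each class. The member names m1 and m2 are hypothetical; since each child initializes its own object, neither can reset the other's submodules.

```python
from torch import nn


class Module1(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1_loss = nn.L1Loss()


class Module2(nn.Module):
    def __init__(self):
        super().__init__()


class Module12(nn.Module):
    """Composition instead of inheritance: Module12 has a Module1 and a Module2."""

    def __init__(self):
        super().__init__()
        # Hypothetical member names; each child runs its own __init__ on its
        # own object, so nothing here can wipe another module's registry.
        self.m1 = Module1()
        self.m2 = Module2()


m = Module12()
print(hasattr(m.m1, "l1_loss"))  # True
print([name for name, _ in m.named_modules()])
```

Because m1 and m2 are registered as submodules, their parameters and buffers are still picked up by m.parameters(), m.to(...) and state_dict(), just under the m1. and m2. prefixes.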