How can I memoize and reuse the output of a shared nn.Module within a forward pass?
class MagicModule(torch.nn.Module):
    """Compute ``torch.norm(x)`` (default arguments: the 2-norm over all elements).

    By default every call recomputes the norm, exactly like a plain module.

    With ``memoize=True`` the module remembers its most recent call. It returns
    the cached result instead of recomputing when it is called again with:

    * the *same tensor object* (identity, not value equality),
    * that has not been modified in place since, and
    * under the same autograd grad mode.

    This lets several parent models that share one instance evaluate it only
    once per input within a forward pass.

    Caveats when memoizing:

    * Every consumer receives the *same* output tensor. Callers must not
      modify it in place.
    * If the input requires grad, all consumers share one autograd sub-graph.
      Combine their losses and call ``backward()`` once, or pass
      ``retain_graph=True``. Calling ``backward()`` separately on each
      consumer's output fails on the second call.
    * The cache holds strong references to the last input and output. Call
      :meth:`reset_cache` between passes (e.g. at the start of each training
      step) to release them.

    Args:
        memoize: enable the single-entry cache described above. Defaults to
            ``False``, which preserves the plain, non-caching behaviour.
    """

    def __init__(self, *, memoize=False):
        super().__init__()
        self.memoize = memoize
        self.reset_cache()

    def reset_cache(self):
        """Forget the cached input/output pair, releasing its references."""
        self._cached_input = None
        self._cached_version = None
        self._cached_grad_mode = None
        self._cached_output = None

    def forward(self, x):
        """Return ``torch.norm(x)``, possibly reused from the previous call."""
        if not self.memoize:
            return torch.norm(x)

        grad_mode = torch.is_grad_enabled()
        # Compare the input by identity (``is``). Tensor ``==`` is elementwise
        # and would also conflate distinct tensors that merely hold equal values.
        # ``_version`` is PyTorch's internal counter that every in-place op
        # bumps (autograd uses it to detect modified tensors). It lets us notice
        # ``x.mul_(...)`` between calls. The grad mode is part of the key so
        # that a result computed under ``torch.no_grad()`` (which has no
        # ``grad_fn``) is never handed out where gradients are expected, and
        # vice versa.
        if (
            x is self._cached_input
            and x._version == self._cached_version
            and grad_mode == self._cached_grad_mode
        ):
            return self._cached_output

        out = torch.norm(x)
        self._cached_input = x
        self._cached_version = x._version
        self._cached_grad_mode = grad_mode
        self._cached_output = out
        return out
class ModelA(torch.nn.Module):
    """Square the output of an injected (possibly shared) submodule.

    Args:
        magic_module: a callable (typically an ``nn.Module``) applied to the
            input first. The same instance may also be held by other models.
    """

    def __init__(self, magic_module):
        super().__init__()
        # Assigning an nn.Module attribute registers it as a child module.
        self.magic_module = magic_module

    def forward(self, x):
        """Return ``magic_module(x) ** 2``."""
        return self.magic_module(x) ** 2
class ModelB(torch.nn.Module):
    """Cube the output of an injected (possibly shared) submodule.

    Args:
        magic_module: a callable (typically an ``nn.Module``) applied to the
            input first. The same instance may also be held by other models.
    """

    def __init__(self, magic_module):
        super().__init__()
        # Assigning an nn.Module attribute registers it as a child module.
        self.magic_module = magic_module

    def forward(self, x):
        """Return ``magic_module(x) ** 3``."""
        return self.magic_module(x) ** 3
# One MagicModule instance is handed to both models, so A and B hold the very
# same submodule object.
magic_module = MagicModule()
A, B = ModelA(magic_module), ModelB(magic_module)

x = torch.rand(3, 4)
# As written, each model calls the shared module itself, so it runs once per
# model call below. The outputs are discarded.
A(x)
B(x)
In this example, `magic_module(x)` is evaluated twice, once inside `A` and once inside `B`, and it produces the same result both times. More generally, how can I memoize that result so that a module shared by several other modules/models is evaluated only once per input?