Forward pass memoization

How can I memoize and reuse the result of an nn.Module in the forward pass?

import torch

class MagicModule(torch.nn.Module):
    def __init__(self):
        super(MagicModule, self).__init__()
        
    def forward(self, x):
        x = torch.norm(x)
        return x
    
class ModelA(torch.nn.Module):
    def __init__(self, magic_module):
        super(ModelA, self).__init__()
        self.magic_module = magic_module
        
    def forward(self, x):
        x = self.magic_module(x)
        x = x ** 2
        return x

class ModelB(torch.nn.Module):
    def __init__(self, magic_module):
        super(ModelB, self).__init__()
        self.magic_module = magic_module
        
    def forward(self, x):
        x = self.magic_module(x)
        x = x ** 3
        return x

magic_module = MagicModule()
A = ModelA(magic_module)
B = ModelB(magic_module)

x = torch.rand(3, 4)
A(x)
B(x)

In the code example, magic_module(x) is called twice and produces the same result both times. In general, how can I memoize the result so that a module shared by several modules/models is only computed once per input?

I'm not sure if there is a better way than moving magic_module outside of both models and calling it only once.
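If the models themselves can't be restructured, one option is to cache inside the shared module. Below is a minimal sketch (not a built-in PyTorch feature): a variant of MagicModule that remembers the last input tensor and its output, and skips recomputation when the exact same tensor object is passed again. The class name MemoizedMagicModule and the calls counter are my own additions for illustration; note the cache matches on tensor identity, so it will miss if the input is recreated or modified in place.

```python
import torch

class MemoizedMagicModule(torch.nn.Module):
    """Hypothetical variant of MagicModule that caches its last result."""

    def __init__(self):
        super(MemoizedMagicModule, self).__init__()
        self._cached_input = None   # last input tensor (compared by identity)
        self._cached_output = None  # result computed for that input
        self.calls = 0              # counts actual computations, for demonstration

    def forward(self, x):
        # Reuse the cached result when the same tensor object is passed again.
        if self._cached_input is x:
            return self._cached_output
        self.calls += 1
        out = torch.norm(x)
        self._cached_input = x
        self._cached_output = out
        return out
```

With this in place, ModelA and ModelB can keep calling self.magic_module(x) unchanged; the second call on the same x returns the cached tensor, and both downstream results still share one autograd graph node. A caveat: holding a reference to the output keeps its graph alive between iterations, so you may want to clear the cache (set both fields back to None) after each training step.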


Hi, thanks for your reply. I'm somewhat trapped in a framework that requires me to structure models in a certain way, so I'm not sure what you suggested is possible. But I'll take a look and see if I can hack it through.