Multiple references to nn.Parameter in different modules


I would like to know how I can keep a reference to an nn.Parameter in both a module and a submodule. Both of them use this parameter, and I don't want to pass it as an argument.

For example:

import torch
import torch.nn as nn

class SubModule(nn.Module):
    def __init__(self, X):
        super().__init__()
        self.X = X  # same nn.Parameter object as BigModule.X

class SubModule2(SubModule):  # placeholder for a second submodule that also uses X
    pass

class BigModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.X = nn.Parameter(torch.ones(10, 10))
        self.submodules = nn.ModuleList([SubModule(self.X), SubModule2(self.X)])

Python offers no elegant solution for this. In addition, your snippet would register X a second time, as a parameter of each submodule, because nn.Module.__setattr__ intercepts nn.Parameter assignments.
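What actually happens can be checked directly. This is a minimal sketch, assuming a recent PyTorch and hypothetical `Sub`/`Big` classes that mirror the question. The same Parameter object ends up registered under several names: `parameters()` and `named_parameters()` deduplicate it, so an optimizer sees X only once, but `state_dict()` lists one key per registration.

```python
import torch
import torch.nn as nn

class Sub(nn.Module):  # hypothetical, mirrors SubModule from the question
    def __init__(self, X):
        super().__init__()
        self.X = X  # __setattr__ registers X in this module's _parameters too

class Big(nn.Module):  # hypothetical, mirrors BigModule from the question
    def __init__(self):
        super().__init__()
        self.X = nn.Parameter(torch.ones(10, 10))
        self.subs = nn.ModuleList([Sub(self.X), Sub(self.X)])

big = Big()
# named_parameters() skips duplicates: one entry, under the first name found
print([name for name, _ in big.named_parameters()])  # ['X']
# state_dict() does not deduplicate: one key per registration
print(list(big.state_dict().keys()))  # ['X', 'subs.0.X', 'subs.1.X']
# Every registration points at the same Parameter object
print(big.subs[0].X is big.X)  # True
```

So there is still only one tensor, but checkpoints and `load_state_dict` (with the default `strict=True`) see it under three keys.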

I use the state object pattern (passing a state argument to forward()) for this. It has two advantages: 1) it works with JIT, and 2) tensors and parameters are interchangeable. The downsides: 1) the state argument must be added to every submodule's forward(), and 2) some boilerplate is needed to initialize the state object.
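A minimal sketch of the state-object idea, assuming the state is a plain dict mapping names to tensors. The poster does not show their actual state class, so the `state` layout and the `SubModule`/`BigModule` bodies here are hypothetical, and this sketch has not been checked under TorchScript. X is registered once on BigModule, and each submodule receives it through forward() instead of holding a reference.

```python
from typing import Dict

import torch
import torch.nn as nn

class SubModule(nn.Module):
    # Hypothetical submodule: stores no reference to X, reads it from `state`
    def __init__(self, dim: int):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor, state: Dict[str, torch.Tensor]) -> torch.Tensor:
        return x @ state["X"] + self.bias

class BigModule(nn.Module):
    def __init__(self, dim: int = 10):
        super().__init__()
        self.X = nn.Parameter(torch.ones(dim, dim))  # registered exactly once
        self.submodules = nn.ModuleList([SubModule(dim), SubModule(dim)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The boilerplate: build the state object once per call
        state = {"X": self.X}
        for sub in self.submodules:
            x = sub(x, state)
        return x

big = BigModule()
out = big(torch.randn(4, 10))
print(out.shape)  # torch.Size([4, 10])
# A plain tensor can stand in for the parameter, e.g. when testing one submodule alone
y = big.submodules[0](torch.ones(2, 10), {"X": torch.eye(10)})
```

Because a submodule only needs a tensor, `torch.eye(10)` works in place of the parameter, which is what makes the two interchangeable here; the price is the extra `state` argument on every forward().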

For simpler scripts, I’d consider global variables:

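import torch
import torch.nn as nn

X = None  # module-level slot shared by BigModule and its submodules

class BigModule(nn.Module):
    def __init__(self):
        global X
        super().__init__()
        # Registered once, on BigModule; submodules read the global X instead
        self.X = X = nn.Parameter(torch.ones(10, 10))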
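Submodules can then read the global X when they are called rather than storing it, so X is registered only on BigModule. A sketch, assuming a hypothetical `SubModule` and `forward()` methods added for illustration:

```python
import torch
import torch.nn as nn

X = None  # module-level slot, filled in by BigModule.__init__

class SubModule(nn.Module):
    # Hypothetical submodule: looks up the global X at call time, stores no reference
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ X

class BigModule(nn.Module):
    def __init__(self):
        global X
        super().__init__()
        self.X = X = nn.Parameter(torch.ones(10, 10))
        self.submodules = nn.ModuleList([SubModule(), SubModule()])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for sub in self.submodules:
            x = sub(x)
        return x

big = BigModule()
print([name for name, _ in big.named_parameters()])  # ['X']
out = big(torch.ones(2, 10))
```

One caveat: every new BigModule rebinds the global, so with several instances all submodules see the X of the most recently constructed one. That is why this only suits simple scripts with a single model.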

Thank you. Just wanted to confirm my guess, which seems to match yours.