How to share a learnable parameter across different layers?

I want to create a model where I have a network-wide learnable parameter which I need to pass to each layer. I have thought of 2 ways of doing this:

(1)

class Func(nn.Module):
    def __init__(self, data_dim, hidden_dim1, hidden_dim2, target_dim):
        super().__init__()        
        self.lamb_mu = nn.Parameter(torch.Tensor(1).uniform_(1, 1))
        self.lamb_rho = nn.Parameter(torch.Tensor(1).uniform_(-6., -6.))

        self.l1 = Lasso_layer(data_dim, hidden_dim1, self.rho_prior)
        self.l2 = Lasso_layer(hidden_dim1, hidden_dim2, self.rho_prior)
        self.l4 = Lasso_layer(hidden_dim2, target_dim, self.rho_prior)

        self.lamb_sigma = None

    def forward(self, X):
        self.lamb_sigma = torch.log1p(torch.exp(self.lamb_rho))

        output = F.relu(self.l1(X.reshape(-1, 28*28), self.lamb_mu, self.lamb_sigma))
        output = F.relu(self.l2(output, self.lamb_mu, self.lamb_sigma))
        output = self.l4(output, self.lamb_mu, self.lamb_sigma)

        return output

(2)

class Func(nn.Module):
    def __init__(self, data_dim, hidden_dim1, hidden_dim2, target_dim):
        super().__init__()        
        self.lamb_mu = nn.Parameter(torch.Tensor(1).uniform_(1, 1))
        self.lamb_rho = nn.Parameter(torch.Tensor(1).uniform_(-6., -6.))

        self.l1 = Lasso_layer(data_dim, hidden_dim1, self.rho_prior, self.lamb_mu, self.lamb_rho)
        self.l2 = Lasso_layer(hidden_dim1, hidden_dim2, self.rho_prior, self.lamb_mu, self.lamb_rho)
        self.l4 = Lasso_layer(hidden_dim2, target_dim, self.rho_prior, self.lamb_mu, self.lamb_rho)

        self.lamb_sigma = None

    def forward(self, X):
        output = F.relu(self.l1(X.reshape(-1, 28*28)))
        output = F.relu(self.l2(output))
        output = self.l4(output)

        return output

Which of these approaches is correct and will update lamb_mu and lamb_rho parameters appropriately for the entire network?

I think (2) won’t work correctly, because it would create parameter clones. (1) is OK; you can also choose other ways to pass the context, e.g.:

class Context(NamedTuple):
  lamb_mu : Tensor
  lamb_sigma : Tensor
...
def forward(self,...):
  sigma = ...
  ctx = Context(mu,sigma)
  ...
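
A fuller version of that idea might look like the sketch below. CtxLayer is a hypothetical stand-in for Lasso_layer, and the way it uses lamb_mu and lamb_sigma is made up purely for illustration; only the pattern of building one Context per forward pass and handing it to every layer is the point.

```python
from typing import NamedTuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class Context(NamedTuple):
    lamb_mu: torch.Tensor
    lamb_sigma: torch.Tensor


class CtxLayer(nn.Module):
    """Hypothetical stand-in for Lasso_layer; receives shared values at call time."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x, ctx):
        # Illustrative use of the shared values only, not the real Lasso_layer math.
        return self.linear(x) * ctx.lamb_mu + ctx.lamb_sigma


class Func(nn.Module):
    def __init__(self, data_dim, hidden_dim, target_dim):
        super().__init__()
        # The shared parameters are registered once, on the top-level module.
        self.lamb_mu = nn.Parameter(torch.ones(1))
        self.lamb_rho = nn.Parameter(torch.full((1,), -6.0))
        self.l1 = CtxLayer(data_dim, hidden_dim)
        self.l2 = CtxLayer(hidden_dim, target_dim)

    def forward(self, x):
        # softplus(rho) is the same as log1p(exp(rho)) from the question.
        ctx = Context(self.lamb_mu, F.softplus(self.lamb_rho))
        return self.l2(F.relu(self.l1(x, ctx)), ctx)


model = Func(4, 8, 2)
model(torch.randn(3, 4)).sum().backward()

names = [name for name, _ in model.named_parameters()]
has_grads = model.lamb_mu.grad is not None and model.lamb_rho.grad is not None
print(names, has_grads)
```

Because the layers never hold the parameters themselves, each one appears exactly once in both named_parameters() and state_dict().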

I think even (2) will work (provided self.rho_prior is defined before the layers are created), since objects of type nn.Parameter are mutable, and we are passing references to these objects into the Lasso_layer constructors. I don’t think parameters are being cloned here.

The one way to be sure is to construct a toy example and test … .
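
Such a toy test could look roughly like this. ToyLayer and ToyNet are made-up names standing in for Lasso_layer and Func, and the layer math is arbitrary; the sketch only checks that a parameter passed into the constructors, as in (2), is the same object everywhere and is updated by a single optimizer step.

```python
import torch
import torch.nn as nn


class ToyLayer(nn.Module):
    """Hypothetical stand-in for Lasso_layer: scales a linear map by a shared scalar."""

    def __init__(self, in_dim, out_dim, shared):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        # Assigning an existing nn.Parameter registers that same object, not a copy.
        self.shared = shared

    def forward(self, x):
        return self.shared * self.linear(x)


class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = nn.Parameter(torch.ones(1))
        self.l1 = ToyLayer(4, 8, self.shared)
        self.l2 = ToyLayer(8, 2, self.shared)

    def forward(self, x):
        return self.l2(torch.relu(self.l1(x)))


torch.manual_seed(0)
net = ToyNet()

# Identity check: every layer holds a reference to the one top-level parameter.
same_object = net.l1.shared is net.shared and net.l2.shared is net.shared

# parameters() removes duplicates: shared + 2 weights + 2 biases.
n_params = len(list(net.parameters()))

opt = torch.optim.SGD(net.parameters(), lr=0.1)
before = net.shared.detach().clone()
loss = net(torch.randn(16, 4)).pow(2).mean()
loss.backward()
opt.step()

# After the step, the value has moved and the layers still see the same object.
changed = not torch.equal(before, net.shared.detach())
still_shared = net.l1.shared is net.shared and net.l2.shared is net.shared
print(same_object, n_params, changed, still_shared)
```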

I do define self.rho_prior before creating lamb_mu and lamb_sigma as well as all the layers.

When I print the values of lamb_mu and lamb_sigma within the forward method of my Lasso_layer, I see that those values stay the same across layers and only change after the loss.backward step occurs, in both approaches (1) and (2). So I think both ways might be correct, but I want to be absolutely certain before I produce results from this code.

yeah, it seems like (2) works, but it seems brittle, since it relies on deduplication code everywhere (i.e. state_dict() has separate keys, but named_parameters() returns one key, and standard serialization correctly stores the object only once)
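
The deduplication behaviour is easy to see on a stripped-down example where one nn.Parameter is assigned as an attribute of several submodules. Layer and Net here are hypothetical minimal classes, not the real Lasso_layer or Func.

```python
import torch
import torch.nn as nn


class Layer(nn.Module):
    def __init__(self, shared):
        super().__init__()
        self.shared = shared  # same object registered again under this submodule


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = nn.Parameter(torch.ones(1))
        self.l1 = Layer(self.shared)
        self.l2 = Layer(self.shared)


net = Net()

# named_parameters() skips parameters it has already yielded.
param_names = [name for name, _ in net.named_parameters()]

# state_dict() keeps one key per registration site.
state_keys = list(net.state_dict().keys())

# Loading copies values in place, so a fresh model keeps its own sharing intact.
net2 = Net()
net2.load_state_dict(net.state_dict())
sharing_kept = net2.l1.shared is net2.shared and net2.l2.shared is net2.shared

print(param_names)  # ['shared']
print(state_keys)   # ['shared', 'l1.shared', 'l2.shared']
```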

@ptrblck would you be kind enough to help me with this?