Autoencoder with tied weights using nn.Sequential

Hi,
I am building an autoencoder like this:

import torch

# example dimensions
dim_input, h_dim1, Z_dim = 784, 128, 32

class autoencoder(torch.nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(dim_input, h_dim1),
            torch.nn.ReLU(),
            torch.nn.Dropout(),
            torch.nn.Linear(h_dim1, Z_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout())
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(Z_dim, h_dim1),
            torch.nn.ReLU(),
            torch.nn.Dropout(),
            torch.nn.Linear(h_dim1, dim_input),
            torch.nn.ReLU(),
            torch.nn.Dropout())

    def forward(self, X):
        Z = self.encoder(X)
        X = self.decoder(Z)
        return X, Z

But the weights in the encoder and the decoder are different. How can I tie them so that the decoder weights are the transpose of the encoder weights (the parameters of the model would then be only the encoder's weights)?

Another question: in a tied-weight autoencoder, if I use dropout in the encoder for regularization, how should I apply it on the decoder side? In the above code, it seems that different nodes are dropped in the encoder and in the decoder.

Thanks!

I would be interested in understanding this as well!

@ptrblck I have the same question and the same Sequential autoencoder, and I am trying to achieve what the original poster asked here: tying the weights between the encoder and decoder layers.

A clean approach would be to define the desired parameters once (via nn.Parameter in the __init__ method of the model) and apply them through the functional API in the forward method (e.g. via F.linear).


@ptrblck can you please share some simple example?

Sure, here is a simple example of using F.linear with a single weight parameter shared by two layers:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # one weight parameter, reused in both directions
        self.weight = nn.Parameter(torch.randn(4, 10))

    def forward(self, x):
        print(x.shape)
        # F.linear(x, w) computes x @ w.T, so the transposed weight maps 4 -> 10
        x = F.linear(x, self.weight.T)
        print(x.shape)
        # ... and the original weight maps 10 -> 4
        x = F.linear(x, self.weight)
        print(x.shape)
        return x

model = MyModel()
x = torch.randn(1, 4)
out = model(x)
> torch.Size([1, 4])
  torch.Size([1, 10])
  torch.Size([1, 4])
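
For reference, here is a minimal sketch of how this pattern could be applied to the architecture from the original post. The class name, sizes, and init scheme are illustrative assumptions; dropout is omitted for brevity, and the decoder gets its own biases since only the weights are tied:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TiedAutoencoder(nn.Module):
    def __init__(self, dim_input=784, h_dim1=128, Z_dim=32):
        super().__init__()
        # the encoder weights are the only weight parameters of the model
        self.w1 = nn.Parameter(torch.randn(h_dim1, dim_input) * 0.01)
        self.b1 = nn.Parameter(torch.zeros(h_dim1))
        self.w2 = nn.Parameter(torch.randn(Z_dim, h_dim1) * 0.01)
        self.b2 = nn.Parameter(torch.zeros(Z_dim))
        # the decoder shares the weights (transposed) but has its own biases
        self.b3 = nn.Parameter(torch.zeros(h_dim1))
        self.b4 = nn.Parameter(torch.zeros(dim_input))

    def forward(self, x):
        # encoder
        z = F.relu(F.linear(x, self.w1, self.b1))
        z = F.relu(F.linear(z, self.w2, self.b2))
        # decoder: F.linear with the transposed encoder weights
        out = F.relu(F.linear(z, self.w2.t(), self.b3))
        out = F.linear(out, self.w1.t(), self.b4)
        return out, z

model = TiedAutoencoder()
x = torch.randn(8, 784)
out, z = model(x)
print(out.shape)  # torch.Size([8, 784])
print(z.shape)    # torch.Size([8, 32])

Regarding the dropout question: nn.Dropout samples a new mask on every call, so the encoder and decoder will indeed drop different nodes. If you want the same nodes dropped in both, one option would be to sample the mask manually (e.g. with torch.bernoulli) in the forward method and multiply it into both the encoder and decoder activations.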

@ptrblck can you mark this as the solution for other viewers? I am not the original poster, so I can't do it.
