How to load one model's output as another model's parameters and do end-to-end optimization

Here is an abstract version of the problem:

Assume we have two models named Encoder and Decoder respectively, where Encoder is like:

class Encoder(nn.Module):
    def __init__(self, n, dim):
        super().__init__()
        self.embedding_A = nn.Embedding(n, dim)   # create an embedding matrix with shape (n, dim)
    def forward(self):
        out_embeddings = .... # do something with self.embedding_A, and get a new matrix with the same shape
        return out_embeddings

while Decoder has the form:

class Decoder(nn.Module):
    def __init__(self, n, dim):
        super().__init__()
        self.embedding_B = nn.Embedding(n, dim)  # same shape as Encoder's output
    def forward(self):
        # do something with self.embedding_B and return a loss for backward

Now I want Decoder’s self.embedding_B to take the output of Encoder (i.e., out_embeddings), and then run an end-to-end optimization from Decoder’s output back to Encoder’s self.embedding_A.

I know an easy way is to change the code and pass Encoder’s result directly as an input to Decoder.forward(). However, my Decoder model is very complex, so making such a change is difficult.

Is there a way to load Encoder’s output (i.e., out_embeddings) into Decoder’s self.embedding_B and chain the two models together? The backward pass would then start at Decoder’s output, go back to self.embedding_B, and continue further back to optimize Encoder’s self.embedding_A.

I have tried code like this:

# *encoder* is an instance of class *Encoder*
# *decoder* is an instance of class *Decoder*

out_embeddings = encoder(...)  # my 1st try
decoder.embedding_B = nn.Parameter(out_embeddings)  # my 2nd try

but neither lets the gradients reach encoder.embedding_A. Is there any way to do this by direct assignment?


  • We will have to remove the embedding parameter from the Decoder's __init__ and accept it as an argument of the Decoder's forward function instead

  • We will also have to pass the input indices to the Decoder's forward so that it knows which embedding rows to use
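
As background for why the direct assignment in the question does not work: wrapping a tensor in nn.Parameter creates a new leaf tensor with no autograd history, so the link back to the encoder is cut. A minimal sketch of that behaviour follows. The names weight_a and the "* 2.0" transform are hypothetical stand-ins for encoder.embedding_A.weight and the encoder's real computation.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for encoder.embedding_A.weight (small size for illustration).
weight_a = nn.Parameter(torch.randn(4, 3))

# Hypothetical stand-in for the encoder output: any differentiable function of weight_a.
out_embeddings = weight_a * 2.0

# Second attempt from the question: wrap the output in a new Parameter.
wrapped = nn.Parameter(out_embeddings)
wrapped.sum().backward()

wrapped_is_leaf = wrapped.is_leaf      # the Parameter is a fresh leaf tensor
grad_after_wrapped = weight_a.grad     # no gradient arrived at weight_a

# Using the tensor itself (no re-wrapping) keeps the graph intact.
out_embeddings.sum().backward()
grad_after_direct = weight_a.grad.clone()
```

This is why the decoder has to consume the encoder's tensor itself rather than a copy stored as its own parameter.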

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Encoder(nn.Module):
    def __init__(self, num_indices):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(num_indices, 100) 
        self.l1 = nn.Linear(200+100, 64)
        self.l2 = nn.Linear(64, 32)
        self.b1 = nn.BatchNorm1d(64)
        self.b2 = nn.BatchNorm1d(32)
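    def forward(self, x, indices):
        # Look up one embedding row per sample and append it to the input features
        cur_embedding = self.embedding(indices)
        x = torch.hstack((x, cur_embedding))
        x = F.relu(self.b1(self.l1(x)))
        x = self.b2(self.l2(x))
        # Return the embedding module as well, so the decoder can reuse it
        return x, self.embedding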

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.l1 = nn.Linear(100 + 32, 200)
        self.b1 = nn.BatchNorm1d(200)
    def forward(self, x, embedding, indices):
        self.embedding_B = embedding(indices)
        x = torch.hstack((x, self.embedding_B))
        x = self.b1(self.l1(x))
        return x

num_indices = 100
a = pd.DataFrame(np.random.randint(num_indices, size=(1000)), columns=['index'])
a['vector'] = np.random.rand(1000, 200).tolist()

encoder = Encoder(num_indices)
decoder = Decoder()

print("Models compiled")

loss = torch.nn.MSELoss()
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=0.001)

X = torch.FloatTensor(a['vector'].values.tolist())
indices = torch.LongTensor(a['index'].values.tolist())

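for cur_epoch in range(10):
    optimizer.zero_grad()
    encoder_output, embeddings = encoder(X, indices)
    decoder_output = decoder(encoder_output, embeddings, indices)
    cur_loss = loss(decoder_output, X)
    # Backpropagate through the decoder into the encoder, including its embedding
    cur_loss.backward()
    optimizer.step()
    print("Epoch {0} Loss is {1}".format(cur_epoch, cur_loss.item()))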
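
If changing the Decoder's forward signature really is not practical, one possible alternative is torch.func.functional_call, which runs a module with some of its parameters temporarily replaced by other tensors. The sketch below assumes PyTorch 2.0 or later. The tanh transform in the Encoder and the squared-mean loss in the Decoder are hypothetical placeholders, since the question elides the real computations; passing indices to the Decoder is likewise only for illustration.

```python
import torch
import torch.nn as nn
from torch.func import functional_call  # assumes PyTorch >= 2.0

class Encoder(nn.Module):
    def __init__(self, n, dim):
        super().__init__()
        self.embedding_A = nn.Embedding(n, dim)

    def forward(self):
        # Placeholder transform (assumption); the real one is elided in the question.
        return torch.tanh(self.embedding_A.weight)

class Decoder(nn.Module):
    def __init__(self, n, dim):
        super().__init__()
        self.embedding_B = nn.Embedding(n, dim)

    def forward(self, indices):
        # Placeholder loss (assumption) computed from embedding_B only.
        return self.embedding_B(indices).pow(2).mean()

torch.manual_seed(0)
encoder, decoder = Encoder(10, 4), Decoder(10, 4)
indices = torch.tensor([0, 3, 7])

out_embeddings = encoder()

# Run the decoder unchanged, but with embedding_B.weight replaced by the
# encoder output for this call only. The replacement keeps its autograd history.
loss = functional_call(decoder, {"embedding_B.weight": out_embeddings}, (indices,))
loss.backward()

grad_a = encoder.embedding_A.weight.grad   # gradient reached the encoder
grad_b = decoder.embedding_B.weight.grad   # the decoder's own weight was not used
```

With this approach the decoder's stored embedding_B.weight is never used, so it receives no gradient and is not updated by the optimizer.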