How to share weights with multiple encoders

I have multiple LSTM encoders like this:

import torch
import torch.nn.functional as F

class PromptEncoder(torch.nn.Module):
    def __init__(self, name, length, embedding_dim, id_offset, init_embs, prompt_ids, **kwargs) -> None:
        super().__init__()
        self.name = name
        self.length = length
        self.prompt_ids = prompt_ids
        # ids of the prompt tokens handled by this encoder (integer tensor, so not trainable)
        self.input_ids = torch.nn.parameter.Parameter(torch.tensor(prompt_ids), requires_grad=False)
        self.embedding_dim = embedding_dim
        self.id_offset = id_offset
        self.embedding = torch.nn.Embedding(length, embedding_dim)
        # row indices 0..length-1 into self.embedding
        self.net_inps = torch.nn.parameter.Parameter(torch.arange(length), requires_grad=False)
        # only hidden_size is shown; the other arguments are assumed so that
        # the LSTM output width equals embedding_dim, which the MLP expects
        self.lstm = torch.nn.LSTM(
            input_size=embedding_dim,
            hidden_size=embedding_dim // 2,  # my code
            bidirectional=True,
            batch_first=True)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim, embedding_dim),
            torch.nn.Linear(embedding_dim, embedding_dim))

    def forward(self, prompt_token_ids, pids=None):
        # create embedding vectors for input ids
        embeds = self.embedding(self.net_inps)
        # do forward calculations; x[0] is the LSTM output sequence
        x = self.lstm(embeds.unsqueeze(0))

        running_weight = self.mlp(x[0]).squeeze(0)

        # map each prompt token id to its row index in this encoder
        prompt_token_ids = (prompt_token_ids.view(-1, 1) == self.input_ids).int().argmax(dim=1)
        # return weights for prompt_token_ids
        return F.embedding(prompt_token_ids, running_weight)

Some of the prompt_ids (or input_ids) could be shared among the encoders. However, each encoder has its own embedding matrix. They are all part of a container Module and are learned together. I want the shared ids to point to a shared embedding, so that if one encoder changes it, the change is reflected in the embeddings of all the others.
This is the forward of the wrapper:

    def forward(self, input_ids, labels, decoder_input_ids=None, pids=None, **kwargs):
        prompt_masks = self.prompt_token_fn(input_ids)
        if prompt_masks.any():
            input_ids_ = input_ids.clone()
            if self.replacing_token_id is not None:
                # replace prompt ids in input_ids with replacing token
                input_ids_[prompt_masks] = self.replacing_token_id
            # find the model embeddings of input ids except for prompt tokens
            inputs_embeds = self.model_embeddings(input_ids_)
            for encoder in self.prompt_encoders:
                #encoder = self.prompt_encoders[0]
                prompt_token_fn = encoder.get_prompt_token_fn()
                encoder_masks = prompt_token_fn(input_ids)
                if encoder_masks.any():
                    # find input ids for prompt tokens
                    prompt_input_ids = input_ids[encoder_masks]
                    # call forward on the prompt encoder, whose outputs are prompt embeddings
                    prompt_embeds = encoder(prompt_input_ids, pids)
                    # in input embeds replace embeddings for prompt tokens with output of encoder
                    inputs_embeds[encoder_masks] = prompt_embeds
        else:
            # no prompt tokens in this batch: use the model embeddings directly
            inputs_embeds = self.model_embeddings(input_ids)
        return self.underlying_model(inputs_embeds=inputs_embeds, **kwargs)

How can I implement such an architecture?

If the encoders are trained separately, why would you want changes to one encoder to be reflected in another?
If you want that to happen, you have to train all of them as part of the same model.

You can use ModuleList and ParameterList to hold the models and the embeddings, respectively.

@andreasceid Thank you very much! The encoders are in a ModuleList. I put more of my code in the question, including how they are called in the forward of the container Module. The container module actually wraps a transformer model (T5), which is frozen, and the results of the forward pass on the encoders are fed into it. I am somewhat of a beginner with PyTorch and Transformers. For example, some parts of my LSTM encoders, such as having both the input_ids parameter and net_inps, might be redundant?

Suppose the input ids for the whole model are [1, 2, 3, 4, 5, 1, 2, 3, 4, 5, ...] in a batch. The ids of encoder 1 are [3, 4] and the ids of encoder 2 are [1, 3, 5], where 3 is common to both. Each encoder must update the embedding for its corresponding input. However, 3 exists in both of them. Maybe I should merge the results for 3 in the loop in forward; I don't know. I thought maybe they must refer to a shared embedding space. Please guide me on this problem.
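
One way to make a shared id point at a single embedding is to let every encoder index into one common table instead of owning its own. The sketch below is a minimal illustration of that idea, not the code from the question: `SharedPromptEncoder`, `row_of`, and the embedding size of 8 are made up for the example, and the LSTM arguments are assumptions.

```python
import torch
import torch.nn as nn


class SharedPromptEncoder(nn.Module):
    """Hypothetical encoder that reads its rows from a shared embedding table."""

    def __init__(self, shared_embedding, rows):
        super().__init__()
        # The same nn.Embedding object is assigned to every encoder, so its
        # weight is a single parameter rather than one copy per encoder.
        self.embedding = shared_embedding
        # Rows of the shared table that belong to this encoder (not trained).
        self.register_buffer("rows", torch.tensor(rows))
        dim = shared_embedding.embedding_dim
        self.lstm = nn.LSTM(dim, dim // 2, bidirectional=True, batch_first=True)

    def forward(self):
        embeds = self.embedding(self.rows)        # (n_ids, dim)
        out, _ = self.lstm(embeds.unsqueeze(0))   # (1, n_ids, dim)
        return out.squeeze(0)                     # (n_ids, dim)


# Ids from the example: encoder 1 owns [3, 4], encoder 2 owns [1, 3, 5].
encoder_ids = [[3, 4], [1, 3, 5]]
all_ids = sorted({i for ids in encoder_ids for i in ids})
row_of = {pid: row for row, pid in enumerate(all_ids)}

shared = nn.Embedding(len(all_ids), 8)
encoders = nn.ModuleList(
    [SharedPromptEncoder(shared, [row_of[i] for i in ids]) for ids in encoder_ids]
)

# Both encoders read the row for id 3, so both send gradients to it.
loss = sum(enc().sum() for enc in encoders)
loss.backward()
```

Because the table is one parameter, `encoders.parameters()` yields it only once, and an optimizer step driven by either encoder's loss moves the row for id 3 for both. Each encoder still has its own LSTM, so their outputs for id 3 can differ; only the input embedding is shared.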

@Ahmad_Pouramini This will not be an easy problem to solve. The best way forward would be to have no intersecting groups.

However, if you want to go ahead, you can use some sort of weight averaging.
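
Averaging could be read in more than one way. One reading is to average the encoders' outputs at the positions that more than one encoder claims, instead of letting the last encoder in the loop overwrite the others. A minimal sketch under that assumption follows; `merge_prompt_embeds` is a hypothetical helper, and the tensors in the usage lines are stand-ins for real encoder outputs.

```python
import torch


def merge_prompt_embeds(inputs_embeds, encoder_outputs):
    """Average encoder outputs wherever several encoders claim the same position.

    inputs_embeds:   (batch, seq, dim) embeddings from the frozen model
    encoder_outputs: list of (mask, embeds) pairs, where mask is a (batch, seq)
                     bool tensor and embeds has shape (mask.sum(), dim)
    """
    summed = torch.zeros_like(inputs_embeds)
    counts = torch.zeros(
        inputs_embeds.shape[:-1],
        dtype=inputs_embeds.dtype,
        device=inputs_embeds.device,
    )
    for mask, embeds in encoder_outputs:
        # Scatter this encoder's rows to full size, then accumulate out of
        # place so autograd can still reach every encoder.
        full = torch.zeros_like(inputs_embeds)
        full[mask] = embeds
        summed = summed + full
        counts = counts + mask.to(counts.dtype)
    covered = counts > 0
    mean = summed / counts.clamp(min=1).unsqueeze(-1)
    # Keep the model's own embedding where no encoder produced anything.
    return torch.where(covered.unsqueeze(-1), mean, inputs_embeds)


# Usage with ids [3, 4] for encoder 1 and [1, 3, 5] for encoder 2.
input_ids = torch.tensor([[1, 2, 3, 4, 5]])
inputs_embeds = torch.zeros(1, 5, 2)
mask1 = torch.isin(input_ids, torch.tensor([3, 4]))
mask2 = torch.isin(input_ids, torch.tensor([1, 3, 5]))
out1 = torch.full((2, 2), 2.0)  # stand-in for encoder 1's output
out2 = torch.full((3, 2), 4.0)  # stand-in for encoder 2's output
merged = merge_prompt_embeds(inputs_embeds, [(mask1, out1), (mask2, out2)])
```

In this toy input, the position holding id 3 receives the mean of the two stand-in outputs, the positions owned by a single encoder keep that encoder's output, and the position holding id 2 keeps the model embedding.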

@anantguptadbl Thanks. Can't I connect the common cells of the two LSTMs to a feed-forward layer, get one output, and use that one?
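
A rough sketch of what that could look like, assuming exactly two encoders share the id and both produce vectors of the same size: concatenate the two outputs for the shared id and project them back down with a linear layer. `SharedIdMerger` and the size 8 are hypothetical names and values for illustration only.

```python
import torch
import torch.nn as nn


class SharedIdMerger(nn.Module):
    """Hypothetical layer that fuses two encoders' outputs for shared ids."""

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(2 * dim, dim)

    def forward(self, emb_a, emb_b):
        # emb_a, emb_b: (n_shared, dim) outputs of the two encoders
        # for the ids they have in common.
        return self.proj(torch.cat([emb_a, emb_b], dim=-1))


dim = 8
merger = SharedIdMerger(dim)
# Stand-ins for the two encoders' outputs for the shared id 3.
emb_a = torch.randn(1, dim, requires_grad=True)
emb_b = torch.randn(1, dim, requires_grad=True)
merged = merger(emb_a, emb_b)   # (1, dim), used in place of either output
merged.sum().backward()         # gradients reach both encoders' outputs
```

Unlike plain averaging, this adds trainable parameters, and the merger would have to be registered on the container module so that it is optimized together with the encoders.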