I have multiple LSTM encoders like this:
import logging

import torch
import torch.nn.functional as F

emblog = logging.getLogger(__name__)  # logger used inside forward()


class PromptEncoder(torch.nn.Module):
    def __init__(self, name, length, embedding_dim, id_offset, init_embs, prompt_ids, **kwargs) -> None:
        super().__init__()
        self.length = length
        self.name = name
        self.prompt_ids = prompt_ids
        # global vocabulary ids handled by this encoder (frozen, saved with the module)
        self.input_ids = torch.nn.parameter.Parameter(torch.tensor(prompt_ids),
                                                      requires_grad=False)
        self.embedding_dim = embedding_dim
        self.id_offset = id_offset
        # each encoder owns its own embedding matrix
        self.embedding = torch.nn.Embedding(length, embedding_dim)
        # local indices 0..length-1 that are fed through the reparameterization
        self.net_inps = torch.nn.parameter.Parameter(torch.arange(length),
                                                     requires_grad=False)
        self.lstm = torch.nn.LSTM(
            input_size=embedding_dim,
            hidden_size=embedding_dim // 2,
            num_layers=2,
            dropout=0,
            bidirectional=True,
            batch_first=True,
        )
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim, embedding_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(embedding_dim, embedding_dim),
        )

    def forward(self, prompt_token_ids, pids=None):
        # create embedding vectors for the local indices
        embeds = self.embedding(self.net_inps)
        # run them through the LSTM + MLP reparameterization
        x = self.lstm(embeds.unsqueeze(0))
        emblog.info("lstm embeds: %s", embeds)
        running_weight = self.mlp(x[0]).squeeze(0)
        # map the global prompt token ids to local positions within input_ids
        prompt_token_ids = (prompt_token_ids.view(-1, 1) == self.input_ids).int().argmax(dim=1)
        # return the reparameterized weights for prompt_token_ids
        return F.embedding(prompt_token_ids, running_weight)
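For context, a single encoder is used roughly like this (the id values and dimensions below are made up):

# hypothetical example: an encoder owning three prompt ids added after the model vocabulary
enc = PromptEncoder(name="task_a", length=3, embedding_dim=8, id_offset=50001,
                    init_embs=None, prompt_ids=[50001, 50002, 50003])

# prompt token ids as they appear inside input_ids
tokens = torch.tensor([50002, 50001, 50002])
print(enc(tokens).shape)  # torch.Size([3, 8]): one reparameterized vector per prompt token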
Some of the prompt_ids (and hence input_ids) can be shared among the encoders, but each encoder currently has its own embedding matrix. All encoders live in a container Module and are trained together. I want the shared ids to point to a single shared embedding, so that when one encoder updates it, the change is reflected in every encoder that uses that id.
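To make the overlap concrete, the encoders are registered in the container roughly like this (the names and id values are made up; 50003 is the id the two encoders share):

# hypothetical container attribute: two encoders whose prompt ids overlap on 50003
prompt_encoders = torch.nn.ModuleList([
    PromptEncoder(name="task_a", length=3, embedding_dim=8, id_offset=50001,
                  init_embs=None, prompt_ids=[50001, 50002, 50003]),
    PromptEncoder(name="task_b", length=2, embedding_dim=8, id_offset=50003,
                  init_embs=None, prompt_ids=[50003, 50004]),
])
# currently each encoder learns its own, separate vector for 50003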
This is the forward wrapper of the container module:
    def forward(self, input_ids, labels, decoder_input_ids=None, pids=None, **kwargs):
        prompt_masks = self.prompt_token_fn(input_ids)
        if prompt_masks.any():
            input_ids_ = input_ids.clone()
            if self.replacing_token_id is not None:
                # replace prompt ids in input_ids with the replacing token
                input_ids_[prompt_masks] = self.replacing_token_id
            # look up the model embeddings of all input ids except the prompt tokens
            inputs_embeds = self.model_embeddings(input_ids_)
            device = inputs_embeds.device
            for encoder in self.prompt_encoders:
                prompt_token_fn = encoder.get_prompt_token_fn()
                encoder_masks = prompt_token_fn(input_ids)
                if encoder_masks.any():
                    # find the input ids belonging to this encoder's prompt tokens
                    prompt_input_ids = input_ids[encoder_masks]
                    # the encoder's forward returns the prompt embeddings
                    prompt_embeds = encoder(prompt_input_ids, pids).to(device)
                    # overwrite the prompt-token positions in inputs_embeds
                    # with the embeddings produced by the encoder
                    inputs_embeds[encoder_masks] = prompt_embeds
        else:
            inputs_embeds = self.model_embeddings(input_ids)
        return self.underlying_model(inputs_embeds=inputs_embeds, **kwargs)
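What I have in mind is roughly the following behaviour, sketched here in isolation (TinyEncoder and all names are placeholders, not my real code): the container owns one shared Embedding, every encoder holds a reference to that same instance, and a gradient step through either encoder updates the same shared rows.

# minimal standalone sketch of the desired sharing, assuming a single shared prompt id
import torch

embedding_dim = 8
shared_embedding = torch.nn.Embedding(1, embedding_dim)  # table for the shared id

class TinyEncoder(torch.nn.Module):
    def __init__(self, own_length, shared_positions, shared_embedding):
        super().__init__()
        self.own = torch.nn.Embedding(own_length, embedding_dim)
        self.shared_positions = shared_positions      # local positions that are shared
        self.shared = shared_embedding                # same Embedding instance in every encoder

    def forward(self):
        embeds = self.own(torch.arange(self.own.num_embeddings)).clone()
        # overwrite the shared positions with rows taken from the shared table
        embeds[self.shared_positions] = self.shared.weight
        return embeds

enc_a = TinyEncoder(3, shared_positions=[2], shared_embedding=shared_embedding)
enc_b = TinyEncoder(2, shared_positions=[0], shared_embedding=shared_embedding)

enc_a().sum().backward()
print(shared_embedding.weight.grad)  # non-zero: training enc_a also moves the shared row used by enc_b

I am not sure how to combine this kind of sharing with the LSTM/MLP reparameterization in PromptEncoder above, which is what my question is about.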
How can I implement such an architecture?