Hello! I need to pretrain the embedding layer of a model in a self-supervised manner and then reuse this pretrained embedding layer in another model with a different structure. Both models have an embedding layer as their first layer. Is there a way to transfer the weights of this layer from the pretrained model EHR_Embedding() to my other model LSTM_model()?
Is it enough to just assign the weights in the following way (as suggested in this post: https://discuss.pytorch.org/t/copying-part-of-the-weights/14199?u=maslenkovas )?
model.embedding.weight.data = pretrained_model.embedding.weight.data
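For concreteness, below is a minimal sketch of the alternatives I am considering instead of the .data assignment. It assumes that model.embedding and pretrained_model.embedding are both plain nn.Embedding layers with matching vocab_size and embedding_size; nothing in it is specific to my classes.

# copy the pretrained weights in-place, outside of autograd
with torch.no_grad():
    model.embedding.weight.copy_(pretrained_model.embedding.weight)

# or, equivalently, load the embedding layer's state_dict
model.embedding.load_state_dict(pretrained_model.embedding.state_dict())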
The classes of both models are below.
Thank you
class EHR_Embedding(nn.Module):
    def __init__(self, embedding_size, vocab_size=15463, drop=0.1):
        super(EHR_Embedding, self).__init__()
        self.embedding_size = embedding_size
        self.vocab_size = vocab_size

        self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
        self.projection = nn.Sequential(
            nn.ReLU(),
            nn.Linear(in_features=embedding_size, out_features=embedding_size)
        )
        self.drop = nn.Dropout(p=drop)
    def forward(self, tensor_demo, tensor_med, tensor_vitals, tensor_labs):
        batch_size = tensor_med.size()[0]

        # first transformation
        emb_demo_X = self.drop(self.embedding(tensor_demo.squeeze(1)))
        emb_med_X = self.drop(self.embedding(tensor_med[:,:].squeeze(1)))
        emb_vitals_X = self.drop(self.embedding(tensor_vitals[:,:].squeeze(1)))
        emb_labs_X = self.drop(self.embedding(tensor_labs[:,:].squeeze(1)))

        projection_demo_X = self.projection(emb_demo_X)
        projection_med_X = self.projection(emb_med_X)
        projection_vitals_X = self.projection(emb_vitals_X)
        projection_labs_X = self.projection(emb_labs_X)

        embedding_X = (emb_demo_X, emb_med_X, emb_vitals_X, emb_labs_X)
        projection_X = (projection_demo_X, projection_med_X, projection_vitals_X, projection_labs_X)

        # second transformation
        emb_demo_Y = self.drop(self.embedding(tensor_demo.squeeze(1)))
        emb_med_Y = self.drop(self.embedding(tensor_med[:,:].squeeze(1)))
        emb_vitals_Y = self.drop(self.embedding(tensor_vitals[:,:].squeeze(1)))
        emb_labs_Y = self.drop(self.embedding(tensor_labs[:,:].squeeze(1)))

        projection_demo_Y = self.projection(emb_demo_Y)
        projection_med_Y = self.projection(emb_med_Y)
        projection_vitals_Y = self.projection(emb_vitals_Y)
        projection_labs_Y = self.projection(emb_labs_Y)

        embedding_Y = (emb_demo_Y, emb_med_Y, emb_vitals_Y, emb_labs_Y)
        projection_Y = (projection_demo_Y, projection_med_Y, projection_vitals_Y, projection_labs_Y)

        return embedding_X, projection_X, embedding_Y, projection_Y
class LSTM_model(nn.Module):
    def __init__(self, pretrained_model, H=128, max_length=max_length, max_day=7, vocab_size=vocab_size, embedding_size=200):
        super(LSTM_model, self).__init__()
        # Hyperparameters
        self.max_day = max_day
        L = (self.max_day+1) * (256 + 256 + 512) + 1280
        self.H = H
        self.max_length = max_length
        self.embedding_size = embedding_size
        self.vocab_size = vocab_size

        self.embedding = pretrained_model

        self.fc_med = nn.Linear(max_length['medications'] * 2 * self.H, 256)  # 65,280
        self.fc_vit = nn.Linear(max_length['vitals'] * 2 * self.H, 256)       # 51,200
        self.fc_lab = nn.Linear(max_length['lab_tests'] * 2 * self.H, 512)    # 102,400

        self.lstm_day = nn.LSTM(input_size=embedding_size,
                                hidden_size=self.H,
                                num_layers=1,
                                batch_first=True,
                                bidirectional=True)

        self.fc_1 = nn.Linear((self.max_day+1) * (256 + 256 + 512) + max_length['demographics'] * 2 * H, 2048)

        self.lstm_adm = nn.LSTM(input_size=2048,
                                hidden_size=self.H,
                                num_layers=1,
                                batch_first=True,
                                bidirectional=False)

        self.drop = nn.Dropout(p=0.5)
        self.fc_2 = nn.Linear(self.H, (self.max_day+1))
        # self.sigmoid = nn.Sigmoid()
    def forward(self, tensor_demo, tensor_med, tensor_vitals, tensor_labs):
        batch_size = tensor_med.size()[0]
        days = self.max_day + 1

        out_emb_med_demo = self.embedding(tensor_demo.squeeze(1))
        output_lstm_day_demo, _ = self.lstm_day(out_emb_med_demo)
        full_output = output_lstm_day_demo.reshape(batch_size, self.max_length['demographics'] * 2 * self.H)

        for d in range(days):
            # embedding layer applied to all tensors
            out_emb_med = self.embedding(tensor_med[:, d, :].squeeze(1))
            out_emb_vitals = self.embedding(tensor_vitals[:, d, :].squeeze(1))
            out_emb_labs = self.embedding(tensor_labs[:, d, :].squeeze(1))

            # lstm layer applied to embedded tensors
            output_lstm_day_med = self.fc_med(
                self.lstm_day(out_emb_med)[0]
                .reshape(batch_size, max_length['medications'] * 2 * self.H))
            output_lstm_day_vitals = self.fc_vit(
                self.lstm_day(out_emb_vitals)[0]
                .reshape(batch_size, max_length['vitals'] * 2 * self.H))
            output_lstm_day_labs = self.fc_lab(
                self.lstm_day(out_emb_labs)[0]
                .reshape(batch_size, max_length['lab_tests'] * 2 * self.H))

            # concatenate the per-day outputs
            full_output = torch.cat((full_output,
                                     output_lstm_day_med,
                                     output_lstm_day_vitals,
                                     output_lstm_day_labs), dim=1)

        # print('full_output size: ', full_output.size())
        output = self.fc_1(full_output)
        output, _ = self.lstm_adm(output)
        output = self.drop(output)
        output = self.fc_2(output)
        output = torch.squeeze(output, 1)
        # if self.criterion == 'BCELoss':
        #     output = self.sigmoid(output)

        return output