I am trying to create a Siamese-style LSTM model that estimates the relevance of a search result to a query. I have built the model, but during training the parameters are not updated after the backward pass. I checked the gradients and they are non-zero, yet the parameters before and after the backward pass are identical. My code is below.
Can someone please help me with this?
class LSTMencoder(nn.Module):
    """Encode a query batch and a description batch with two stacked LSTMs each.

    The query path (lstm_q_1 -> lstm_q_2) and the description path
    (lstm_d_1 -> lstm_d_2) have separate LSTM weights but share a single
    embedding layer built by ``create_emb_layer(weights_matrix)``.
    """

    def __init__(self, weights_matrix, hidden_dim_q_1, hidden_dim_q_2, hidden_dim_d_1, hidden_dim_d_2, device):
        super(LSTMencoder, self).__init__()
        # create_emb_layer is defined elsewhere; it returns (layer, num_embeddings, embedding_dim).
        self.embedding_token, num_embeddings, embedding_dim = create_emb_layer(weights_matrix)
        self.lstm_q_1 = nn.LSTM(embedding_dim, hidden_dim_q_1, batch_first=True)
        self.lstm_q_2 = nn.LSTM(hidden_dim_q_1, hidden_dim_q_2, batch_first=True)
        self.lstm_d_1 = nn.LSTM(embedding_dim, hidden_dim_d_1, batch_first=True)
        self.lstm_d_2 = nn.LSTM(hidden_dim_d_1, hidden_dim_d_2, batch_first=True)
        # Feature sizes of the encoder outputs; read by Siamese_lstm to size its classifier.
        self.final_hidden_q = hidden_dim_q_2
        self.final_hidden_d = hidden_dim_d_2
        self.device = device

    @staticmethod
    def _last_timestep(padded, lengths):
        """Return ``padded[i, lengths[i] - 1, :]`` for every row ``i`` as a (batch, hidden) tensor."""
        # Indices must be int64 and live on the same device as the tensor being gathered.
        idx = lengths.to(device=padded.device, dtype=torch.long) - 1
        idx = idx.view(-1, 1, 1).expand(-1, 1, padded.size(2))
        return padded.gather(1, idx).squeeze(1)

    @staticmethod
    def _restore_order(rows, ix):
        """Place row ``i`` of ``rows`` at position ``ix[i]`` (undo the length-sort of the batch)."""
        index = ix.to(rows.device).view(-1, 1).expand(-1, rows.size(1))
        # Out-of-place scatter is differentiable w.r.t. ``rows``.
        return torch.zeros_like(rows).scatter(0, index, rows)

    def forward(self, query_batch, desc_batch, q_lengths, d_lengths, q_ix, d_ix):
        """Encode both batches and return their last-time-step features.

        Args:
            query_batch, desc_batch: token-id batches, batch dimension first.
                Assumed padded and sorted by descending length, as PACK requires
                by default — TODO confirm against the data loader.
            q_lengths, d_lengths: true (unpadded) sequence lengths per row.
            q_ix, d_ix: ``q_ix[i]`` is the target position of sorted row ``i``
                in the restored (original) batch order.

        Returns:
            ``(last_output_q, last_output_d, last_output_query, last_output_desc)``:
            the final-step features in sorted order, followed by the same
            features scattered back into the order given by ``q_ix`` / ``d_ix``.
        """
        query_embedding = self.embedding_token(query_batch)
        desc_embedding = self.embedding_token(desc_batch)

        q_packed = PACK(query_embedding, q_lengths, batch_first=True)
        d_packed = PACK(desc_embedding, d_lengths, batch_first=True)

        output_q1, _ = self.lstm_q_1(q_packed)
        output_q2, _ = self.lstm_q_2(output_q1)
        output_d1, _ = self.lstm_d_1(d_packed)
        output_d2, _ = self.lstm_d_2(output_d1)

        # Use the lengths returned by pad_packed_sequence: they are guaranteed to
        # line up with the rows of the padded tensor.
        output_q_padded, output_q_lengths = pad_packed_sequence(output_q2, batch_first=True)
        output_d_padded, output_d_lengths = pad_packed_sequence(output_d2, batch_first=True)

        last_output_q = self._last_timestep(output_q_padded, output_q_lengths).to(self.device)
        last_output_d = self._last_timestep(output_d_padded, output_d_lengths).to(self.device)

        last_output_query = self._restore_order(last_output_q, q_ix)
        last_output_desc = self._restore_order(last_output_d, d_ix)
        return last_output_q, last_output_d, last_output_query, last_output_desc
After the LSTM forward pass, I take the feature vector at the final (non-padded) time step of each query and description using the `gather` function. In the last step, I use the `scatter` function to restore the original order of the query and description batches.
class Siamese_lstm(nn.Module):
    """Relevance classifier: encode query and description, concatenate, classify.

    Args:
        weights_matrix: forwarded to LSTMencoder to build the shared embedding.
        hidden_dim_q_1, hidden_dim_q_2: hidden sizes of the two query LSTMs.
        hidden_dim_d_1, hidden_dim_d_2: hidden sizes of the two description LSTMs.
        label_size: number of output classes (last dimension of the returned scores).
        device: forwarded to LSTMencoder.
    """

    def __init__(self, weights_matrix, hidden_dim_q_1, hidden_dim_q_2, hidden_dim_d_1, hidden_dim_d_2, label_size, device):
        super(Siamese_lstm, self).__init__()
        self.encoder = LSTMencoder(weights_matrix, hidden_dim_q_1, hidden_dim_q_2, hidden_dim_d_1, hidden_dim_d_2, device)
        # The classifier input is the query feature and description feature side by side.
        self.feature_len = self.encoder.final_hidden_q + self.encoder.final_hidden_d
        self.final_layer = nn.Linear(self.feature_len, 512)
        self.final_layer_1 = nn.Linear(512, label_size)

    def forward(self, query_batch, desc_batch, q_lengths, d_lengths, q_ix, d_ix):
        """Return unnormalised class scores (logits) with last dimension ``label_size``.

        All arguments are passed straight through to ``self.encoder``; only the
        re-ordered query/description features it returns are used here.
        """
        _, _, query_feature, desc_feature = self.encoder(query_batch, desc_batch, q_lengths, d_lengths, q_ix, d_ix)
        final_vector = torch.cat((query_feature, desc_feature), 1)
        # Without a non-linearity the two Linear layers compose to a single affine
        # map, so the 512-unit hidden layer would add parameters but no capacity.
        hidden = torch.relu(self.final_layer(final_vector))
        return self.final_layer_1(hidden)
Once the LSTM encoder has encoded the query and the description, the two encodings are concatenated into a single feature vector, which is passed through two fully connected layers to predict the output.
# Train/validate `model` for `num_epochs` epochs and keep the weights with the
# best validation accuracy. Relies on names defined elsewhere: model, optimizer,
# scheduler, criterion, data_loaders, data_lengths, num_epochs, device.
#
# NOTE(review): to confirm that optimizer.step() changes the weights, snapshot
# them with `p.detach().clone()` first. `state_dict()` returns tensors that
# share storage with the live parameters, so an un-copied snapshot always
# compares equal after the update (this is why deepcopy is used below).
since = time.time()
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0
for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch + 1, num_epochs))
    print('-' * 10)
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()
        else:
            model.eval()
        running_loss = 0.0
        running_corrects = 0
        for query_batch, desc_batch, query_lengths, desc_lengths, q_sorted_idx, d_sorted_idx, labels in data_loaders[phase]:
            query_batch = query_batch.to(device)
            desc_batch = desc_batch.to(device)
            labels = labels.to(device)
            q_sorted_idx = q_sorted_idx.to(device)
            d_sorted_idx = d_sorted_idx.to(device)
            optimizer.zero_grad()
            # Build the autograd graph only while training.
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(query_batch, desc_batch, query_lengths, desc_lengths, q_sorted_idx, d_sorted_idx)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
            running_loss += loss.item() * query_batch.size(0)
            # .item() keeps the counter a plain Python number instead of a device tensor.
            running_corrects += torch.sum(preds == labels).item()
        if phase == 'train':
            # Since PyTorch 1.1 the scheduler must step after optimizer.step();
            # stepping it first skips the initial value of the LR schedule.
            scheduler.step()
        epoch_loss = running_loss / data_lengths[phase]
        epoch_acc = running_corrects / data_lengths[phase]
        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
        if phase == 'val' and epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())
    print()
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best Accuracy: {:.4f}'.format(best_acc))
model.load_state_dict(best_model_wts)
The loss function is cross-entropy and the optimizer is Adam.