I am trying to code a simple NER model (BiLSTM) with character level embeddings (also modelled using BiLSTM). The idea to concatenate character embedding (computed from BiLSTM) with the word embeddings, this concatenated tensor is fed to the BiLSTM to label sequence. In my current implementation I am using for-loop
to compute character representation of every word token, is there a way I can avoid this loop?
The shape of char_seq
is: batch_size
, seq_len
, token_len (max_chars_in_a_token)
.
import torch
import torch.nn as nn
import torch.nn.functional as f
class Net(nn.Module):
def __init__(self, params):
super(Net, self).__init__()
self.word_embedding = nn.Embedding(params.word_vocab_size, params.word_embedding_dim)
self.char_embedding = nn.Embedding(params.char_vocab_size, params.char_embedding_dim)
self.pos_embedding = nn.Embedding(params.pos_size, params.pos_embedding_dim)
self.cap_embedding = nn.Embedding(params.cap_size, params.cap_embedding_dim)
self.c_lstm = nn.LSTM(params.char_embedding_dim, params.char_hidden_dim, bidirectional=True, batch_first=True)
self.lstm = nn.LSTM(params.input_dim, params.hidden_dim, bidirectional=True, batch_first=True)
self.fc = nn.Linear(params.hidden_dim * 2, params.number_of_tags)
self.char_embedding_dim = params.char_embedding_dim
def forward(self, s):
def get_char_representation(char_seq):
s_ch_e = self.char_embedding(char_seq)
char_vec = torch.empty(s_ch_e.shape[0], s_ch_e.shape[1], s_ch_e.shape[3] * 2)
for idx, ch in enumerate(s_ch_e):
s_ch_rep, _ = self.c_lstm(ch)
s_ch_rep_f = s_ch_rep[:, -1, 0: self.char_embedding_dim]
s_ch_rep_b = s_ch_rep[:, 0, self.char_embedding_dim:]
s_ch_rep = torch.cat((s_ch_rep_f, s_ch_rep_b), dim=1)
char_vec[idx] = s_ch_rep
return char_vec
s_words, s_pos, s_cap, s_chars = s
s_w = self.word_embedding(s_words)
s_p = self.pos_embedding(s_pos)
s_c = self.cap_embedding(s_cap)
s_ch = get_char_representation(s_chars)
s = torch.cat((s_w, s_p, s_c, s_ch), dim=2)
s, _ = self.lstm(s)
s = s.contiguous()
s = s.view(-1, s.shape[2])
s = self.fc(s)
return f.log_softmax(s, dim=1)