I am trying to augment Char Embeddings with Word Embeddings to enhance my NER model, the problem is the code is super slow (locally) and it crashes Colab every time, any ideas on how to improve it (performance-wise)
Here is my code
class DualTagger(nn.Module):
def __init__(self, hparams):
super(DualTagger, self).__init__()
self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.word_embedding = nn.Embedding(hparams.vocab_size, hparams.embedding_dim)
self.char_embedding = nn.Embedding(hparams.char_vocab_size, hparams.char_embedding_dim)
self.char_lstm = nn.LSTM(hparams.char_embedding_dim, hparams.char_hidden_dim)
self.lstm = nn.LSTM(hparams.embedding_dim + hparams.char_hidden_dim,
hparams.hidden_dim)
# self.dropout = nn.Dropout(hparams.dropout)
self.hidden2tag = nn.Linear(hparams.hidden_dim, hparams.num_classes)
self.hidden_char = (autograd.Variable(torch.zeros(1, 1, hparams.char_hidden_dim)).to(self._device),
autograd.Variable(torch.zeros(1, 1, hparams.char_hidden_dim)).to(self._device))
self.hidden_words = (autograd.Variable(torch.zeros(1, 1, hparams.hidden_dim)).to(self._device),
autograd.Variable(torch.zeros(1, 1, hparams.hidden_dim)).to(self._device))
def forward(self, sentence, word_tensors_):
logits_ = None
for idx, word_idx in enumerate(sentence):
word_chars_tensor = word_tensors_.get(int(word_idx))
char_embeds = self.char_embedding(word_chars_tensor)
lstm_char_out, self.hidden_char = self.char_lstm(char_embeds.view(len(word_chars_tensor), 1, -1), self.hidden_char)
word_embed = self.word_embedding(word_idx)
embeds_cat = torch.cat((word_embed.view(1, 1, -1), lstm_char_out[-1].view(1, 1, -1)), dim=2)
lstm_out, self.hidden_words = self.lstm(embeds_cat, self.hidden_words)
logits = self.hidden2tag(lstm_out.view(1, -1))
if idx == 0:
logits_ = logits
else:
logits_ = torch.cat((logits_, logits), dim=0)
return logits_
class HyperParameters():
def __init__(self, model_name_, vocab, char_vocab, label_vocab, embeddings_, batch_size_):
self.model_name = model_name_
self.char_vocab_size, self.vocab_size = len(char_vocab), len(vocab)
self.hidden_dim, self.char_hidden_dim = 256, 256
self.embedding_dim, self.char_embedding_dim = 300, 300
self.num_classes = len(label_vocab)
self.bidirectional = False
self.num_layers = 1
self.dropout = 0.4
self.embeddings = embeddings_
self.batch_size = batch_size_
class Trainer(object):
def __init__(self, model, loss_function, optimizer, verbose):
self.model = model
self.loss_function = loss_function
self.optimizer = optimizer
self._verbose = verbose
def train(self, train_dataset: Dataset, valid_dataset: Dataset, epochs: int = 1, word_tensors=None):
train_loss = 0.0
for epoch in tqdm(range(epochs), desc="Training Epochs"):
epoch_loss = 0.0
self.model.train()
for step, sample in tqdm(enumerate(train_dataset), desc='Training Batches'):
inputs = sample['inputs']
labels = sample['outputs']
self.optimizer.zero_grad()
for input_, tags in zip(tqdm(inputs), labels):
input_ = input_[input_.nonzero()]
predictions = self.model(input_, word_tensors)
predictions = predictions.view(-1, predictions.shape[-1])
tags = tags[tags.nonzero()].view(-1)
sample_loss = self.loss_function(predictions, tags)
sample_loss.backward(retain_graph=True)
train_loss += sample_loss.item()
clip_grad_norm_(self.model.parameters(), 5.) # Gradient Clipping
self.optimizer.step()
epoch_loss += sample_loss.tolist()
avg_epoch_loss = epoch_loss / len(train_dataset)
train_loss += avg_epoch_loss
valid_loss = self.evaluate(valid_dataset, word_tensors)
if self._verbose > 0:
print(f'Epoch {epoch}: [loss = {avg_epoch_loss:0.4f}, val_loss = {valid_loss:0.4f}]')
avg_epoch_loss = train_loss / epochs
return avg_epoch_loss
def evaluate(self, valid_dataset, word_tensors):
valid_loss = 0.0
self.model.eval()
with torch.no_grad():
for sample in valid_dataset:
inputs = sample['inputs']
labels = sample['outputs']
for input_, tags in zip(tqdm(inputs), labels):
input_ = input_[input_.nonzero()]
predictions = self.model(input_, word_tensors)
predictions = predictions.view(-1, predictions.shape[-1])
tags = tags[tags.nonzero()].view(-1)
sample_loss = self.loss_function(predictions, tags)
valid_loss += sample_loss.tolist()
return valid_loss / len(valid_dataset)
def predict(self, x):
self.model.eval()
with torch.no_grad():
logits = self.model(x)
predictions = torch.argmax(logits, -1)
return logits, predictions