Implementation: Augmenting the LSTM part-of-speech tagger with character-level features

I was following the excellent tutorials on PyTorch’s website. I modified the code for “An LSTM for Part-of-Speech Tagging” to implement the exercise, which asks you to add a second LSTM that produces a character-level representation of each word, concatenate it with the word embedding, and train the tagger on the result.
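For reference, the exercise’s architecture can be summarized as below. This is a sketch in my own shorthand: x_w is the word embedding of w, c_w is the final hidden state of the character LSTM run over the characters of w, h_i is the word LSTM’s hidden state at position i, and d_emb, d_char stand for embedding_dim and hidden_dim_char. The practical consequence is that the word LSTM’s input size must be embedding_dim + hidden_dim_char.

$$
c_w = h^{\mathrm{char}}_{|w|}(w), \qquad
\mathrm{input}_i = [\,x_{w_i} ; c_{w_i}\,] \in \mathbb{R}^{d_{\mathrm{emb}} + d_{\mathrm{char}}}, \qquad
\hat{y}_i = \log\operatorname{softmax}(W h_i + b)
$$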

My network code is as follows:

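    import torch
    import torch.autograd as autograd
    import torch.nn as nn
    import torch.nn.functional as F

    class LSTMTaggerWithChar(nn.Module):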
        def __init__(self, embedding_dim, hidden_dim, embedding_dim_char, hidden_dim_char, vocab_size, vocab_size_char, target_size):
            super(LSTMTaggerWithChar, self).__init__()
            self.hidden_dim = hidden_dim
            self.hidden_dim_char = hidden_dim_char
            
            self.embedding_char = nn.Embedding(vocab_size_char, embedding_dim_char)
            self.lstm_char = nn.LSTM(embedding_dim_char, hidden_dim_char)
            
            self.embedding = nn.Embedding(vocab_size, embedding_dim)
            self.lstm = nn.LSTM(embedding_dim, hidden_dim)
            self.hidden2tag = nn.Linear(hidden_dim, target_size)
            
            self.hidden = self.init_hidden()
            self.hidden_char = self.init_hidden_char()
            
        def init_hidden(self):
            return (autograd.Variable(torch.zeros(1, 1, self.hidden_dim)), autograd.Variable(torch.zeros(1, 1, self.hidden_dim)))
        
        def init_hidden_char(self):
            return (autograd.Variable(torch.zeros(1, 1, self.hidden_dim_char)), autograd.Variable(torch.zeros(1, 1, self.hidden_dim_char)))
            
        def forward(self, sentence, words):
            
            for ix, word in enumerate(sentence):
                chars = words[ix]
                # self.hidden_char = self.init_hidden_char()  # Should I re-initialize hidden_char here?
                char_embeds = self.embedding_char(chars).view(len(chars), 1, -1)
                lstm_char_out, self.hidden_char = self.lstm_char(char_embeds, self.hidden_char)
                
                char_rep = lstm_char_out[-1]

                embeds = self.embedding(word).view(1, 1, -1)
                lstm_out, self.hidden = self.lstm(embeds, self.hidden)
                
                tag_score = F.log_softmax(self.hidden2tag(lstm_out.view(1, -1)), dim=1)
                
                if ix == 0:
                    tag_scores = tag_score
                else:
                    tag_scores = torch.cat((tag_scores, tag_score), 0)
            return tag_scores
  1. Even if I uncomment the line self.hidden_char = self.init_hidden_char() in the forward function, I get the same results. I don’t understand why this happens.

  2. I also think the character LSTM’s hidden state should be reset after it produces the representation of a word, assuming the representation of the next word is unrelated to the previous one. But if I do that, it is not clear to me how back-propagation will reach the character LSTM when I call loss.backward() during training.

  3. Is it OK to have for loops in the forward function? How can they be avoided?


I am learning PyTorch by going through the tutorial as well. I tried to implement the character-level LSTM for POS tagging; the code is below. Please correct me if you find anything unreasonable.

# -*- coding: utf-8 -*-

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)


######################################################################
# Example: An LSTM for Part-of-Speech Tagging
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Prepare data:

def prepare_sequence(seq, to_ix):
    idxs = [to_ix[w] for w in seq]
    tensor = torch.LongTensor(idxs)
    return autograd.Variable(tensor)

training_data = [
    ("The dog ate the apple".split(), ["DET", "NN", "V", "DET", "NN"]),
    ("Everybody read that book".split(), ["NN", "V", "DET", "NN"])
]
word_to_ix = {}
char_to_ix = {}
MAX_WORD_LEN = 0
for sent, tags in training_data:
    for word in sent:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)
            for ch in word:
                if ch not in char_to_ix:
                    char_to_ix[ch] = len(char_to_ix)
            if len(word) > MAX_WORD_LEN:
                MAX_WORD_LEN = len(word)
char_to_ix[' '] = len(char_to_ix)
print(word_to_ix)
print(char_to_ix)
print(MAX_WORD_LEN)
tag_to_ix = {"DET": 0, "NN": 1, "V": 2}

# These will usually be more like 32 or 64 dimensional.
# We will keep them small, so we can see how the weights change as we train.
EMBEDDING_DIM = 16
HIDDEN_DIM = 16
CHAR_EMBEDDING_DIM = 3
CHAR_HIDDEN_DIM = 3

######################################################################
# Create the model:


class LSTMTagger(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, vocab_size, char_hidden_dim, char_embedding_dim, alphabet_size, max_word_len, tagset_size):
        super(LSTMTagger, self).__init__()

        # word embedding
        self.hidden_dim = hidden_dim
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)

        # char embedding
        self.char_hidden_dim = char_hidden_dim
        self.char_embeddings = nn.Embedding(alphabet_size, char_embedding_dim)
        self.char_lstm = nn.LSTM(char_embedding_dim, char_hidden_dim)

        self.overall_hidden_dim = hidden_dim + max_word_len * char_hidden_dim

        # The linear layer that maps from hidden state space to tag space
        self.hidden2tag = nn.Linear(self.overall_hidden_dim, tagset_size)
        self.hidden = self.init_hidden()
        self.char_hidden = self.init_hidden(isChar=True)

    def init_hidden(self, isChar=False):
        # Before we've done anything, we don't have any hidden state.
        # Refer to the Pytorch documentation to see exactly
        # why they have this dimensionality.
        # The axes semantics are (num_layers, minibatch_size, hidden_dim)
        if isChar:
            return (autograd.Variable(torch.zeros(1, 1, self.char_hidden_dim)),
                autograd.Variable(torch.zeros(1, 1, self.char_hidden_dim)))
        else:
            return (autograd.Variable(torch.zeros(1, 1, self.hidden_dim)),
                autograd.Variable(torch.zeros(1, 1, self.hidden_dim)))

    def forward(self, sentence, chars):
        embeds = self.word_embeddings(sentence)
        # print 'LEN SENTENCE', len(sentence)
        # print 'HIDDEN', self.hidden
        lstm_out, self.hidden = self.lstm(
            embeds.view(len(sentence), 1, -1), self.hidden)

        embedc = self.char_embeddings(chars)
        char_lstm_out, self.char_hidden = self.char_lstm(embedc.view(len(chars), 1, -1), self.char_hidden)

        merge_out = torch.cat((lstm_out.view(len(sentence), -1), char_lstm_out.view(len(sentence), -1)), 1)

        tag_space = self.hidden2tag(merge_out)
        tag_scores = F.log_softmax(tag_space, dim=1)
        return tag_scores

######################################################################
# Train the model:


# Argument order follows __init__: char_hidden_dim comes before char_embedding_dim
model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), CHAR_HIDDEN_DIM, CHAR_EMBEDDING_DIM, len(char_to_ix), MAX_WORD_LEN, len(tag_to_ix))
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# See what the scores are before training
# Note that element i,j of the output is the score for tag j for word i.
inputs = prepare_sequence(training_data[0][0], word_to_ix)
sent_chars = []
for w in training_data[0][0]:
    sps = ' ' * (MAX_WORD_LEN - len(w))
    sent_chars.extend(list(sps + w))
inputc = prepare_sequence(sent_chars, char_to_ix)
tag_scores = model(inputs, inputc)
print(tag_scores)

for epoch in range(300):  # again, normally you would NOT do 300 epochs, it is toy data
    for sentence, tags in training_data:
        # Step 1. Remember that Pytorch accumulates gradients.
        # We need to clear them out before each instance
        model.zero_grad()

        # Also, we need to clear out the hidden state of the LSTM,
        # detaching it from its history on the last instance.
        model.hidden = model.init_hidden()
        model.char_hidden = model.init_hidden(isChar=True)

        # Step 2. Get our inputs ready for the network, that is, turn them into
        # Variables of word indices.
        sentence_in = prepare_sequence(sentence, word_to_ix)
        sent_chars = []
        for w in sentence:
            sps = ' ' * (MAX_WORD_LEN - len(w))
            sent_chars.extend(list(sps + w))
        char_in = prepare_sequence(sent_chars, char_to_ix)
        targets = prepare_sequence(tags, tag_to_ix)

        # Step 3. Run our forward pass.
        tag_scores = model(sentence_in, char_in)

        # Step 4. Compute the loss, gradients, and update the parameters by
        #  calling optimizer.step()
        loss = loss_function(tag_scores, targets)
        loss.backward()
        optimizer.step()

# See what the scores are after training
inputs = prepare_sequence(training_data[0][0], word_to_ix)
sent_chars = []
for w in training_data[0][0]:
    sps = ' ' * (MAX_WORD_LEN - len(w))
    sent_chars.extend(list(sps + w))
inputc = prepare_sequence(sent_chars, char_to_ix)
tag_scores = model(inputs, inputc)
# The sentence is "the dog ate the apple".  i,j corresponds to score for tag j
#  for word i. The predicted tag is the maximum scoring tag.
# Here, we can see the predicted sequence below is 0 1 2 0 1
# since 0 is index of the maximum value of row 1,
# 1 is the index of maximum value of row 2, etc.
# Which is DET NOUN VERB DET NOUN, the correct sequence!
print(tag_scores)

The results look like this:

Variable containing:
-0.0829 -2.7836 -4.0300
-6.9329 -0.0083 -4.9270
-3.9040 -3.5350 -0.0506
-0.0214 -4.8225 -4.3353
-4.4914 -0.0152 -5.5591
[torch.FloatTensor of size 5x3]

It looks correct:

 The(0:DET) dog(1:NN) ate(2:V) the(0:DET) apple(1:NN)

The key is to concatenate the two hidden tensors before they are fed into the hidden2tag layer.
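A quick shape check of that concatenation, as a sketch: it assumes the toy sizes from the code above (HIDDEN_DIM = 16, CHAR_HIDDEN_DIM = 3, and MAX_WORD_LEN = 9, the length of “Everybody”) and uses random tensors in place of the real LSTM outputs. Because every word is left-padded to MAX_WORD_LEN characters, view(len(sentence), -1) puts the MAX_WORD_LEN character outputs of each word into one row.

```python
import torch

torch.manual_seed(0)

# Toy sizes assumed from the script above: 5 words, every word padded to 9 chars
n_words, max_word_len = 5, 9
hidden_dim, char_hidden_dim = 16, 3

# Stand-in for the word LSTM output: one hidden vector per word
lstm_out = torch.randn(n_words, 1, hidden_dim)
# Stand-in for the char LSTM output over the padded character stream:
# each word contributes exactly max_word_len time steps
char_lstm_out = torch.randn(n_words * max_word_len, 1, char_hidden_dim)

# view(n_words, -1) groups the max_word_len char outputs of each word into one row
merge_out = torch.cat((lstm_out.view(n_words, -1),
                       char_lstm_out.view(n_words, -1)), dim=1)
print(merge_out.shape)  # torch.Size([5, 43]), i.e. 16 + 9 * 3 features per word
```

Note that each row holds all the character outputs of the padded word rather than only the final hidden state, and the character LSTM runs over the whole sentence in one call, so its state is carried across word boundaries.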


I did not look at your code thoroughly, but it seems it doesn’t work the way the tutorial requires. The tutorial says “to get the character level representation, do an LSTM over the characters of a word, and let the character-level representation of the word be the final hidden state of this LSTM”, but I could not find a line that does this.

Also, these two lines don’t make sense to me, since you reshaped char_lstm_out with len(sentence). The length of a sentence measured in words differs from its length measured in characters.

Did you ever figure out the answers to your questions? I converged to almost the same code and have the same doubts.
Also, I noticed that my implementation became incredibly slow after adding the character LSTM. Did that happen to you as well?

@silpara, it seems you forgot to concatenate the character-level representation of each word with its word embedding as the input to the second LSTM. In __init__, you should have:

self.lstm = nn.LSTM(embedding_dim + hidden_dim_char, hidden_dim)

And then in forward:

lstm_out, self.hidden = self.lstm(torch.cat([embeds, char_rep.view(1, 1, -1)], dim=2), self.hidden)
# char_rep has shape (1, hidden_dim_char), so reshape it to (1, 1, hidden_dim_char) to match embeds;
# the input here then has shape (1, 1, embedding_dim + hidden_dim_char)
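Putting those two changes together, a minimal sketch of the corrected per-word model could look like the following. The class name CharWordTagger, the toy sizes and the indices in the usage lines are hypothetical. The sketch also keeps the word LSTM’s state in a local variable instead of self.hidden, and passes no hidden state to the character LSTM, so nn.LSTM starts every word from zeros (the reset discussed in question 2).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CharWordTagger(nn.Module):
    """Hypothetical variant of LSTMTaggerWithChar: the char LSTM's final hidden
    state (c_w) is concatenated with the word embedding before the word LSTM."""

    def __init__(self, embedding_dim, hidden_dim, embedding_dim_char, hidden_dim_char,
                 vocab_size, vocab_size_char, target_size):
        super().__init__()
        self.embedding_char = nn.Embedding(vocab_size_char, embedding_dim_char)
        self.lstm_char = nn.LSTM(embedding_dim_char, hidden_dim_char)
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # The word LSTM input grows by hidden_dim_char because of the concatenation
        self.lstm = nn.LSTM(embedding_dim + hidden_dim_char, hidden_dim)
        self.hidden2tag = nn.Linear(hidden_dim, target_size)

    def forward(self, sentence, words):
        hidden = None  # None tells nn.LSTM to start from a zero state
        scores = []
        for word, chars in zip(sentence, words):
            char_embeds = self.embedding_char(chars).view(len(chars), 1, -1)
            # No hidden state passed: the char LSTM starts from zeros for every word
            _, (h_char, _) = self.lstm_char(char_embeds)
            char_rep = h_char.view(1, 1, -1)  # final hidden state, c_w
            embeds = self.embedding(word).view(1, 1, -1)
            lstm_out, hidden = self.lstm(torch.cat([embeds, char_rep], dim=2), hidden)
            scores.append(F.log_softmax(self.hidden2tag(lstm_out.view(1, -1)), dim=1))
        return torch.cat(scores, dim=0)  # (num_words, target_size)


torch.manual_seed(0)
model = CharWordTagger(6, 6, 3, 3, vocab_size=10, vocab_size_char=20, target_size=3)
sentence = torch.tensor([0, 1, 2])
words = [torch.tensor([3, 4, 5]), torch.tensor([6, 7]), torch.tensor([8, 9, 10, 11])]
tag_scores = model(sentence, words)
print(tag_scores.shape)  # torch.Size([3, 3])
```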

Now to answer your questions, I would say:

  1. Because char_rep is never used: the character LSTM’s output does not feed into tag_scores, so resetting its hidden state cannot change the results.
  2. I agree with the reset. Resetting the hidden state does not change how gradients are accumulated at each .backward() call: each word’s character-level representation is still computed from the character LSTM’s weights, so those weights still receive gradients and get trained, even if, for instance, you update the parameters after each word.
  3. For a very large model running on a GPU, I would say that a for loop in the forward method is a problem and will slow down computation, though I am not sure about that. To be safe, you can avoid it by looping over words outside the call to model(sentence, words), in which case you would modify your code so that forward takes one word at a time.
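To make point 2 concrete, here is a small check under toy assumptions (random embeddings, two made-up words, and the sum of the word representations standing in for the tagging loss): even though the character LSTM starts from a zero state for every word, its weights still receive gradients after backward().

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
char_emb = nn.Embedding(20, 3)
char_lstm = nn.LSTM(3, 3)

words = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]  # made-up char indices
reps = []
for chars in words:
    # No hidden state passed: each word starts from zeros (the "reset")
    _, (h_n, _) = char_lstm(char_emb(chars).view(len(chars), 1, -1))
    reps.append(h_n.view(-1))

loss = torch.stack(reps).sum()  # stand-in for the real tagging loss
loss.backward()

# The reset only cuts the link between words; each representation still depends
# on the char LSTM's weights, so they get gradients.
grad = char_lstm.weight_ih_l0.grad
print(grad is not None and grad.abs().sum().item() > 0)  # True
```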

@rgalhama, I think it makes sense that it is slower: before, we had one LSTM over the sequence of words, and now, for each of these words, another LSTM runs over its sequence of characters. I observed a 5 to 10x slowdown.


Here is my code for this. I wrote two versions: one follows the same training schedule as the tutorial example, i.e. it inputs the entire sentence at once, and the other works one word at a time.

The advantage of processing a whole sentence is that the sequence goes through the network in one call, which is faster than repeating the same operation len(sequence) times. However, I wanted to avoid looping over words within the forward method, so I had to pack the list of character sequences into a single batch. In that batch each word is a separate sequence whose initial hidden state is its own slice of the zero state created by init_hidden(..., batch_size=number of words), so no character-level state is carried over from one word to the next, just as in the word-by-word version, where hidden_char is re-initialized before every word.

The problem is that the character sequences have variable lengths (e.g. (9, 4, 4, 4) for “Everybody read that book”), and one way I found to handle this in PyTorch was to use torch.nn.utils.rnn.pack_sequence (from version 0.4, which was the current master on GitHub at the time).

Both versions live in the same class, LSTMCharTagger. The regular forward method implements the whole-sentence version with a PackedSequence; forward_one_word implements the second version, where I run the model and accumulate gradients word by word. (Note: in both cases I zero out the gradients and update the parameters at the sentence level.)
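In more recent PyTorch releases (1.1 and later, as far as I know), pack_sequence accepts enforce_sorted=False, which removes the need for the manual argsort/unsort, and the LSTM’s h_n already holds each word’s final hidden state in the original word order. A sketch under those assumptions, with random character indices standing in for the real ones:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_sequence

torch.manual_seed(0)
char_emb = nn.Embedding(30, 3)
char_lstm = nn.LSTM(3, 3)

# Random char indices with the lengths of "Everybody read that book"
lengths = [9, 4, 4, 4]
chars = [torch.randint(0, 30, (n,)) for n in lengths]

# enforce_sorted=False sorts internally, so no manual argsort is needed
packed = pack_sequence([char_emb(c) for c in chars], enforce_sorted=False)
_, (h_n, _) = char_lstm(packed)  # no hidden state passed: zeros for every word
c_w = h_n[-1]                    # (num_words, hidden), in the original word order

# Same result as running the char LSTM on each word separately from a zero state
for k, c in enumerate(chars):
    _, (h_k, _) = char_lstm(char_emb(c).view(len(c), 1, -1))
    assert torch.allclose(c_w[k], h_k.view(-1), atol=1e-5)
print(c_w.shape)  # torch.Size([4, 3])
```

The loop at the end checks that the packed call matches running the character LSTM on each word from a zero state, i.e. words in a packed batch do not share hidden state.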

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from time import time

torch.manual_seed(1)

def prepare_sequence(seq, to_ix):
    idxs = [to_ix[w] for w in seq]
    tensor = torch.LongTensor(idxs)
    return autograd.Variable(tensor)

training_data = [
    ("The dog ate the apple".split(), ["DET", "NN", "V", "DET", "NN"]),
    ("Everybody read that book".split(), ["NN", "V", "DET", "NN"])
]
word_to_ix = {}
for sent, tags in training_data:
    for word in sent:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)
print(word_to_ix)
tag_to_ix = {"DET": 0, "NN": 1, "V": 2}
char_to_ix = {}
for sent,_ in training_data:
    for w in sent:
        for char in w:
            if char not in char_to_ix:
                char_to_ix[char] = len(char_to_ix)

EMBEDDING_DIM = 6
HIDDEN_DIM = 6
CHAR_EMBEDDING = 3
CHAR_LEVEL_REPRESENTATION_DIM = 3

def prepare_both_sequences(sentence, word_to_ix, char_to_ix):
    chars = [prepare_sequence(w, char_to_ix) for w in sentence]
    return prepare_sequence(sentence, word_to_ix), chars

class LSTMCharTagger(nn.Module):
    '''
    Augmented model, takes both sequence of words and char to predict tag.
    Characters are embedded and then get their own representation for each WORD.
    It is this representation that is merged with word embeddings and then fed to the sequence
    LSTM which decodes the tags.
    '''
    def __init__(self, word_embedding_dim, char_embedding_dim, hidden_dim,
                 hidden_char_dim, vocab_size, charset_size, tagset_size):
        super(LSTMCharTagger, self).__init__()
        self.hidden_dim = hidden_dim
        self.hidden_char_dim = hidden_char_dim

        # Word embedding:
        self.word_embedding = nn.Embedding(vocab_size, word_embedding_dim)

        # Char embedding and encoding into char-lvl representation of words (c_w):
        self.char_embedding = nn.Embedding(charset_size, char_embedding_dim)
        self.char_lstm = nn.LSTM(char_embedding_dim, hidden_char_dim)

        # Sequence model:
        self.lstm = nn.LSTM(word_embedding_dim + hidden_char_dim, hidden_dim)
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

        # Init hidden state for lstms
        self.hidden = self.init_hidden(self.hidden_dim)
        self.hidden_char = self.init_hidden(self.hidden_char_dim)

    def init_hidden(self, size, batch_size=1):
        "Batch size argument used when PackedSequence are used"
        return (autograd.Variable(torch.zeros(1, batch_size, size)),
                autograd.Variable(torch.zeros(1, batch_size, size)))

    def forward_one_word(self, word_sequence, char_sequence):
        ''' For a word by word processing.
        '''
        # Word Embedding
        word_embeds = self.word_embedding(word_sequence)
        # Char-level representation of the word with the 1st LSTM
        char_embeds = self.char_embedding(char_sequence)
        char_lvl, self.hidden_char = self.char_lstm(char_embeds.view(len(char_sequence),1,-1), self.hidden_char)
        # Merge
        merged = torch.cat([word_embeds.view(1,1,-1), char_lvl[-1].view(1,1,-1)], dim=2)
        # Predict tag with 2nd LSTM:
        lstm_out, self.hidden = self.lstm(merged, self.hidden)
        tag_space = self.hidden2tag(lstm_out.view(1, -1))
        tag_scores = F.log_softmax(tag_space, dim=1)
        return tag_scores

    def forward(self, word_sequence, char_sequence):
        ''' Importantly, char_sequence is a list of tensors, one per word, and one tensor 
        must represent a whole sequence of character for a given word.
        E.g.: if word_sequence has length 4, char_sequence must be of length 4, thus char_lstm
        will output 4 char-level word representations (c_w).

        Here we deal with variable lengths of character tensors sequence using nn.utils.rnn.pack_sequence
        '''
        # Word Embedding
        word_embeds = self.word_embedding(word_sequence)

        # Char-level representation of each word with the 1st LSTM
        # We will pack variable length embeddings in PackedSequence. Must sort by decreasing length first.
        sorted_length = np.argsort([char_sequence[k].size()[0] for k in range(len(char_sequence))])
        sorted_length = sorted_length[::-1] # decreasing order
        char_embeds = [self.char_embedding(char_sequence[k]) for k in sorted_length]
        packed = nn.utils.rnn.pack_sequence(char_embeds) # pack variable length sequence
        out, self.hidden_char = self.char_lstm(packed, self.hidden_char)
        encodings_unpacked, seqlengths = nn.utils.rnn.pad_packed_sequence(out, batch_first=True) # unpack and pad
        # We need to take only last element in sequence of lstm char output for each word:
        unsort_list = np.argsort(sorted_length) # indices to put list of encodings in original word order
        char_lvl = torch.stack([encodings_unpacked[k][seqlengths[k]-1] for k in unsort_list])

        # Merge
        merged = torch.cat([word_embeds, char_lvl], dim=1) # gives tensor of size (#words, #concatenated features)

        # Predict tag with 2nd LSTM:
        lstm_out, self.hidden = self.lstm(merged.view(len(word_sequence), 1, -1), self.hidden)
        tag_space = self.hidden2tag(lstm_out.view(len(word_sequence), -1))
        tag_scores = F.log_softmax(tag_space, dim=1)
        return tag_scores

def get_batch_size(seq2pack):
    "Need this to correctly initialize batch lstm hidden states when packing variable length sequences..."
    sorted_length = np.argsort([seq2pack[k].size()[0] for k in range(len(seq2pack))])
    sorted_length = sorted_length[::-1] # decreasing order
    packed = nn.utils.rnn.pack_sequence([seq2pack[k] for k in sorted_length]) 
    return max(packed.batch_sizes)

model = LSTMCharTagger(EMBEDDING_DIM, CHAR_EMBEDDING, HIDDEN_DIM, CHAR_LEVEL_REPRESENTATION_DIM,
                       len(word_to_ix), len(char_to_ix), len(tag_to_ix))
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# See what the scores are before training
words_in, chars_in = prepare_both_sequences(training_data[0][0], word_to_ix, char_to_ix)
model.hidden_char = model.init_hidden(model.hidden_char_dim, batch_size=get_batch_size(chars_in))
tag_score = model(words_in, chars_in)
print(tag_score)

t0 = time()
for epoch in range(300): 
    for sentence, tags in training_data:
        # Step 1. Remember that Pytorch accumulates gradients.
        model.zero_grad()

        # Step 2. Get our inputs ready
        sentence_in, chars_in = prepare_both_sequences(sentence, word_to_ix, char_to_ix)
        targets = prepare_sequence(tags, tag_to_ix)
        model.hidden = model.init_hidden(model.hidden_dim)
        model.hidden_char = model.init_hidden(model.hidden_char_dim, batch_size=get_batch_size(chars_in))

        # Step 3. Run our forward pass.
        tag_score = model(sentence_in, chars_in)

        # Step 4. Compute the loss, gradients, and update the parameters
        loss = loss_function(tag_score, targets)
        loss.backward()
        optimizer.step()
print("300 epochs in %.2f sec for model with packed sequences"%(time()-t0))

model = LSTMCharTagger(EMBEDDING_DIM, CHAR_EMBEDDING, HIDDEN_DIM, CHAR_LEVEL_REPRESENTATION_DIM,
                       len(word_to_ix), len(char_to_ix), len(tag_to_ix))
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

t0 = time()
for epoch in range(300):
    for sentence, tags in training_data:
        sentence_score = []
        # Step 1. Remember that Pytorch accumulates gradients.
        model.zero_grad()

        # Step 2. Get our inputs ready
        sentence_in, chars_in = prepare_both_sequences(sentence, word_to_ix, char_to_ix)
        targets = prepare_sequence(tags, tag_to_ix)
        model.hidden = model.init_hidden(model.hidden_dim)
        #model.hidden_char = model.init_hidden(model.hidden_char_dim)

        # Step 3. Run our forward pass on each word
        for k in range(len(sentence)):
            # Clear hidden state between EACH word (char level representation must be independent of previous word)
            model.hidden_char = model.init_hidden(model.hidden_char_dim)
            tag_score = model.forward_one_word(sentence_in[k], chars_in[k])
            sentence_score.append(tag_score)
            loss = loss_function(tag_score, targets[k].view(1,))
            loss.backward(retain_graph=True) # accumulate gradients now
            #tag_score = autograd.Variable(torch.cat(sentence_score), requires_grad=True)

        # Step 4. Update parameters at the end of sentence
        optimizer.step()
print("300 epochs in %.2f sec for model at word level"%(time()-t0))

# See what the scores are after training
words_in, chars_in = prepare_both_sequences(training_data[0][0], word_to_ix, char_to_ix)
model.hidden_char = model.init_hidden(model.hidden_char_dim, batch_size=get_batch_size(chars_in))
tag_score = model(words_in, chars_in)
print(tag_score)

@Hugo-W

I have read through all of your code, but I can’t understand this line in the function forward_one_word:

merged = torch.cat([word_embeds.view(1,1,-1), char_lvl[-1].view(1,1,-1)], dim=2)

Here are my questions:

  1. Why use char_lvl[-1]? Is it the last character in char_lvl?
  2. Should we train the model with words and the affix?

Thanks, and I look forward to your answer.

Unfortunately I have not reused my own code since that post, so I no longer have it in mind; I had to reread it, probably just as you did, to understand what I did (and I am not testing it as I write this). I would advise you to run the code line by line, including the lines inside forward_one_word, and compare char_lvl[-1].view(1,1,-1) with char_lvl.

My guess (again, sorry, I cannot verify the code right now) is that char_lvl[-1] is the character LSTM’s output at the last character of the word, i.e. its final hidden state: in forward_one_word, the input char_sequence holds the characters of a single word (chars_in[k] in the training loop). So it is just a way to merge the two representations of one word (its word embedding and its character-level representation).
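One way to see what char_lvl[-1] holds is to run a single word through a character LSTM with toy sizes (the embedding and LSTM sizes and the character indices below are arbitrary assumptions):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
char_embedding = nn.Embedding(20, 3)
char_lstm = nn.LSTM(3, 3)

char_sequence = torch.tensor([4, 1, 7, 2])  # characters of ONE word
char_embeds = char_embedding(char_sequence)
char_lvl, (h_n, c_n) = char_lstm(char_embeds.view(len(char_sequence), 1, -1))

print(char_lvl.shape)      # torch.Size([4, 1, 3]): one output per character
print(char_lvl[-1].shape)  # torch.Size([1, 3]): output at the last character
# For a single-layer, unidirectional LSTM the last output equals the final hidden state
print(torch.allclose(char_lvl[-1], h_n[0]))  # True
```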

Do you mean the part-of-speech tags (the labels in this task) by “affix”? If so, yes: you train the model with words as input and the target POS tags as output, and the training procedure is also in the code I posted. I wrap the inputs and outputs in the corresponding sequences with prepare_sequence:

sentence_in, chars_in = prepare_both_sequences(sentence, word_to_ix, char_to_ix)
targets = prepare_sequence(tags, tag_to_ix)
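For illustration, here is what these helpers produce for the second training sentence. This sketch uses a modern equivalent of prepare_sequence (plain tensors instead of autograd.Variable) and builds the word and character vocabularies from this one sentence only, so the indices differ from those in the full script.

```python
import torch


def prepare_sequence(seq, to_ix):
    # Modern equivalent of the helper in the script: no Variable wrapper needed
    return torch.tensor([to_ix[w] for w in seq], dtype=torch.long)


def prepare_both_sequences(sentence, word_to_ix, char_to_ix):
    # One LongTensor of char indices per word, plus the word index tensor
    chars = [prepare_sequence(w, char_to_ix) for w in sentence]
    return prepare_sequence(sentence, word_to_ix), chars


sentence = "Everybody read that book".split()
tags = ["NN", "V", "DET", "NN"]
# Assumption: vocabularies built from this single sentence
word_to_ix = {w: i for i, w in enumerate(sentence)}
char_to_ix = {}
for w in sentence:
    for ch in w:
        char_to_ix.setdefault(ch, len(char_to_ix))
tag_to_ix = {"DET": 0, "NN": 1, "V": 2}

sentence_in, chars_in = prepare_both_sequences(sentence, word_to_ix, char_to_ix)
targets = prepare_sequence(tags, tag_to_ix)
print(sentence_in)                 # tensor([0, 1, 2, 3])
print([len(c) for c in chars_in])  # [9, 4, 4, 4]
print(targets)                     # tensor([1, 2, 0, 1])
```

The variable lengths of chars_in are exactly why the whole-sentence version needs pack_sequence.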