Lm cannot train


(Eric Jiang) #1

Based on example code. I change some code to solve big data problem.
Instead push all data in memory then turn to gpu, I use dataloader to generate batch data, and control vocab size less than 50000.
dataloder code is this:

`
class ContLMDataset(Dataset):
“”"dataset that cat the sentences into one long sequence and chunk

 Each training sample is a chunked version and of same length.

    Attributes:
        - vocab: Vocab object which holds the vocabulary info
        - file_path: the directory of all train, test and valid corpus
        - bptt: sequence length
"""

def __init__(self, file_path, vocab=None, bptt=35):
    super(ContLMDataset, self).__init__()
    self.vocab = vocab
    self.file_path = file_path
    self.bptt = bptt
    self.tokenize(file_path)


def tokenize(self, path):
    """Tokenizes a text file."""
    assert os.path.exists(path)
    # add the start of sentence token
    sentence_sep = [BOS]
    with open(path, 'r') as f:
        sentences = [BOS]
        for sentence in tqdm(f, desc='Processing file: {}'.format(path)):
            sentences += sentence.split() + sentence_sep
    # split into list of tokens
    self.data = sentences

def __getitem__(self, index):
    data = self.data[index * self.bptt:(index + 1) * self.bptt]
    target = self.data[index * self.bptt:(index + 1) * self.bptt]
    return [self.vocab.word2idx[word] for word in data],  [self.vocab.word2idx[word] for word in target]

def __len__(self):
    return len(self.data) // self.bptt


def pad_collate_fn(batch):
    data = [data for data, targets in batch]
    targets = [targets for data, targets in batch]
    return torch.LongTensor(data), torch.LongTensor(targets)


class Corpus(object):
      def __init__(self, path, vocab_path=None, batch_size=1, shuffle=False,
                 pin_memory=False, update_vocab=False, min_freq=1,
                 concat=False, bptt=35):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.base_path = path
        self.update_vocab = update_vocab
        self.bptt = bptt
        self.concat = concat

        self.vocab = get_vocab(path, ['train.txt'], min_freq=min_freq, vocab_file=vocab_path)
        if self.concat:
            # set the frequencies for special tokens by miracle trial
            self.vocab.idx2count[1] = self.vocab.freqs[BOS]  # <s>
            self.vocab.idx2count[2] = 0  # </s>

        self.train = self.get_dataloader('train.txt', self.batch_size)
        self.valid = self.get_dataloader('valid.txt', 10)
        self.test = self.get_dataloader('test.txt', 1)

def get_dataloader(self, filename, bs=1):
    full_path = os.path.join(self.base_path, filename)
    dataset = ContLMDataset(full_path, vocab=self.vocab, bptt=self.bptt)
    return DataLoader(
        dataset=dataset,
        batch_size=bs,
        shuffle=self.shuffle,
        pin_memory=self.pin_memory,
        collate_fn=pad_collate_fn,
        # num_workers=1,
        # waiting for a new torch version to support
        drop_last=True,

`

vocab code is :

    """Build the vocabulary from corpus

    This file is forked from pytorch/text repo at Github.com"""
    import os
    import dill as pickle
    import logging
    from collections import defaultdict, Counter

    from tqdm import tqdm
    logger = logging.getLogger(__name__)

    UNK = '<unk>'  # unknown word
    BOS = '<s>'  # sentence start
    EOS = '</s>'  # sentence end

    def _default_unk_index():
        return 0


    def load_freq(freq_file):
        """Load the frequency from text file"""
        counter = Counter()
        with open(freq_file) as f:
        for line in f:
                word, freq = line.split(' ')
                counter[word] = freq
        return counter


    def write_freq(counter, freq_file):
        """Write the word-frequency pairs into text file

        File format:

        word1 freq1
        word2 freq2

    """
        # sort by frequency, then alphabetically
        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
        with open(freq_file, 'w') as f:
            for word, freq in words_and_frequencies:
                f.writelines('{} {}\n'.format(word, freq))



class Vocab(object):
    """Defines a vocabulary object that will be used to numericalize a field.
    Attributes:
        freqs: A collections.Counter object holding the frequencies of tokens
            in the data used to build the Vocab.
        word2idx: A collections.defaultdict instance mapping token strings to
            numerical identifiers.
        idx2word: A list of token strings indexed by their numerical identifiers.
    """
    def __init__(self, counter, max_size=None, min_freq=1):
        """Create a Vocab object from a collections.Counter.
        Arguments:
            counter: collections.Counter object holding the frequencies of
                each value found in the data.
            max_size: The maximum size of the vocabulary, or None for no
                maximum. Default: None.
            min_freq: The minimum frequency needed to include a token in the
                vocabulary. Values less than 1 will be set to 1. Default: 1.
        """
        self.freqs = counter
        self.max_size = max_size
        self.min_freq = min_freq
        self.specials = [UNK, BOS, EOS]
        self.build()


    def build(self, force_vocab=[]):
        """Build the required vocabulary according to attributes

        We need an explicit <unk> for NCE because this improve the precision of
        word frequency estimation in noise sampling

        Args:
            - force_vocab: force the vocabulary to be within this vocab
        """
        counter = self.freqs.copy()
        if force_vocab:
            min_freq = 1
        min_freq = max(self.min_freq, 1)

        # delete the special tokens from given vocabulary
        force_vocab = [w for w in force_vocab if w not in self.specials]

        # Do not count the BOS and UNK as frequency term
        for word in self.specials:
            del counter[word]

        self.idx2word = self.specials + force_vocab
        max_size = None if self.max_size is None else self.max_size + len(self.idx2word)

        # sort by frequency, then alphabetically
        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)

        unk_freq = 0
        for word, freq in words_and_frequencies:

            # for words not in force_vocab and with freq<th, throw to <unk>
            if freq < min_freq and word not in force_vocab:
                unk_freq += freq
            elif len(self.idx2word) != max_size and not force_vocab:
                self.idx2word.append(word)

        self.word2idx = defaultdict(_default_unk_index)
        self.word2idx.update({
            word: idx for idx, word in enumerate(self.idx2word)
        })

        self.idx2count = [self.freqs[word] for word in self.idx2word]
        # set the frequencies for special tokens by miracle trial
        self.idx2count[0] += unk_freq  # <unk>
        self.idx2count[1] = 0  # <s>
        self.idx2count[2] = self.freqs['</s>']  # </s>

    def __eq__(self, other):
        if self.freqs != other.freqs:
            return False
        if self.word2idx != other.word2idx:
            return False
        if self.idx2word != other.idx2word:
            return False
        return True

    def __len__(self):
        return len(self.idx2word)

    def extend(self, v, sort=False):
        words = sorted(v.idx2word) if sort else v.idx2word
        # TODO: speedup the dependency
        for w in words:
            if w not in self.word2idx:
                self.idx2word.append(w)
                self.word2idx[w] = len(self.idx2word) - 1


def check_vocab(vocab):
    """A util function to check the vocabulary correctness"""
    # one word for one index
    assert len(vocab.idx2word) == len(vocab.word2idx)

    # no duplicate words in idx2word
    assert len(set(vocab.idx2word)) == len(vocab.idx2word)


def get_vocab(base_path, file_list, min_freq=1, force_recount=False, vocab_file=None):
    """Build vocabulary file with each line the word and frequency

    The vocabulary object is cached at the first build, aiming at reducing
    the time cost for pre-process during training large dataset

    Args:
        - sentences: sentences with BOS and EOS
        - min_freq: minimal frequency to truncate
        - force_recount: force a re-count of word frequency regardless of the
        Count cache file
        - vocab_file: a specific vocabulary file. If not None, the returned
        vocabulary will only count the words in vocab_file, with others treated
        as <unk>

    Return:
        - vocab: the Vocab object
    """
    counter = Counter()
    cache_file = os.path.join(base_path, 'vocab.pkl')

    if os.path.exists(cache_file) and not force_recount:
        logger.debug('Load cached vocabulary object')
        vocab = pickle.load(open(cache_file, 'rb'))
        if min_freq:
            vocab.min_freq = min_freq
        logger.debug('Load cached vocabulary object finished')
    else:
        logger.debug('Refreshing vocabulary')
        for filename in file_list:
            full_path = os.path.join(base_path, filename)
            for line in tqdm(open(full_path, 'r'), desc='Building vocabulary: '):
                counter.update(line.split())
                counter.update([BOS, EOS])
        vocab = Vocab(counter, min_freq=min_freq)
        logger.debug('Refreshing vocabulary finished')

        # saving for future uses
        freq_file = os.path.join(base_path, 'freq.txt')
        write_freq(vocab.freqs, freq_file)
        pickle.dump(vocab, open(cache_file, 'wb'))

    force_vocab = []
    if vocab_file:
        with open(vocab_file) as f:
            force_vocab = [line.strip() for line in f]
    vocab.build(force_vocab=force_vocab)
    check_vocab(vocab)
    return vocab

Then I add optimizer adam. and change optimizer.zero_grad() and optimizer.step().

After do this, lm model can not train. val loss near 7, basic don’t change.


#2

Could you post the training code?
It would be interesting to see where you call your optimizer. Maybe it’s just a simple mistake.


(Eric Jiang) #3
def train(model, data_source, epoch, lr=1.0, weight_decay=1e-5, momentum=0.9):
    optimizer = optim.Adam(
        params=model.parameters(),
        lr=lr,
        # momentum=momentum,
        weight_decay=weight_decay
    )
    # Turn on training mode which enables dropout.
    model.train()
    model.criterion.nce = args.nce
    total_loss = 0
    pbar = tqdm(data_source, desc='Training PPL: ....')
    for num_batch, data_batch in enumerate(pbar):
        optimizer.zero_grad()
        data, target, length = process_data(data_batch, cuda=args.cuda, sep_target=sep_target)
        loss = model(data, target, length)
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        total_loss += loss.item()

        if args.prof:
            break
        if num_batch % args.log_interval == 0 and num_batch > 0:
            cur_loss = total_loss / args.log_interval
            ppl = math.exp(cur_loss)
            logger.debug(
                '| epoch {:3d} | {:5d}/{:5d} batches '
                '| lr {:02.2f} | loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, num_batch, len(corpus.train),
                    lr, cur_loss, ppl
                  )
            )
            pbar.set_description('Training PPL %.1f' % ppl)
            total_loss = 0

#4

It looks like your loss is calculated in the forward function of your model.
Could you post this code, so that we can make sure for example that you are not detaching any tensors?
Also, you could try to use just one single sample and try to overfit your model on it.
If the loss still doesn’t move at all, there is most likely a bug in your code.