Training is taking a very long time for my model


(Michael Li) #1

Hi all,

I am working on the Quora Questions dataset and have built the following SiameseLSTM model in PyTorch. I’m using a minibatch size of 64 and the Adam optimizer with a learning rate of 0.002. I would appreciate it if someone could take a look at my code and let me know if I am doing anything wrong. My model seems to be using too much memory: it basically freezes and stops training after anywhere from 10 to 20 batches. I think the problem may be in how I call repackage_hidden() or zero_grad() on the optimizer. My code is pasted below:

EDIT (3/28): I have applied the suggested fixes to the code below, and it now runs much faster, thanks to @ptrblck!

#!/usr/bin/env python
import bcolz
import csv
import pickle
import random
import re

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
try:
    from qhoptim.pyt import QHM, QHAdam
except ImportError:
    # qhoptim is optional; it is only needed for the alternative optimizers.
    pass

import numpy as np

from torch.autograd import Variable

torch.manual_seed(1)

GLOVE_PATH = 'data'
# ADDL = '/home/wli24/nlpresearch/'
ADDL = ''

# Use the following for dept machine
TRAIN_DATA_PATH = "{}data/train.csv".format(ADDL)
TEST_DATA_PATH = "{}data/test.csv".format(ADDL)
# Vocabulary index reserved for the "STOP_PAD" padding token.
STOPWORD_IDX = 0

EMBEDDING_DIM = 50
HIDDEN_DIM = 100
# Questions are padded to this many tokens; longer questions are dropped
# when the training data is loaded.
MAX_SENTENCE_LENGTH = 35
BATCH_SIZE = 64
# Limit on the number of CSV rows read by load_train_data (inf = all rows).
MAX_DATASET_SIZE = float("inf")
NUM_LAYERS = 1

# Next free vocabulary index handed out by vectorize(); 0 is STOP_PAD.
curr_idx = 1

dictionary = {"STOP_PAD": STOPWORD_IDX}

# Pre-extracted GloVe data: a bcolz array of vectors plus a pickled word list
# and a pickled word -> row-index mapping.  pickle can execute arbitrary code
# while loading, so only use these files from a trusted source.
vectors = bcolz.open("{}{}/6B.{}.dat".format(ADDL, GLOVE_PATH, EMBEDDING_DIM))[:]
with open("{}{}/6B.{}_words.pkl".format(ADDL, GLOVE_PATH, EMBEDDING_DIM), 'rb') as words_file:
    words = pickle.load(words_file)
with open("{}{}/6B.{}_idx.pkl".format(ADDL, GLOVE_PATH, EMBEDDING_DIM), 'rb') as idx_file:
    word2idx = pickle.load(idx_file)

glove = {w: vectors[word2idx[w]] for w in words}

def normalize_text(text, max_len):
    ''' Pre process and convert texts to a list of words

    Lower-cases ``text``, replaces characters outside a small allow-list with
    spaces, expands some English contractions, splits a few punctuation marks
    into their own tokens and finally splits on whitespace.  If fewer than
    ``max_len`` tokens remain, the list is right-padded with "STOP_PAD";
    longer lists are returned unchanged (the caller decides what to do).

    Args:
        text: any object; it is converted with ``str()`` first.
        max_len: target length used for padding.

    Returns:
        list of token strings.
    '''
    text = str(text)
    text = text.lower()

    # Clean the text.  The hyphen is escaped: an unescaped "+-=" inside a
    # character class is the range '+'..'=', which also silently kept ';' and
    # '<'.  ':' is listed explicitly because it is split into its own token
    # further down.
    text = re.sub(r"[^A-Za-z0-9^,!.\/'+\-=:]", " ", text)
    text = re.sub(r"what's", "what is ", text)
    text = re.sub(r"\'s", " ", text)
    text = re.sub(r"\'ve", " have ", text)
    text = re.sub(r"can't", "cannot ", text)
    text = re.sub(r"n't", " not ", text)
    text = re.sub(r"i'm", "i am ", text)
    text = re.sub(r"\'re", " are ", text)
    text = re.sub(r"\'d", " would ", text)
    text = re.sub(r"\'ll", " will ", text)
    text = re.sub(r",", " ", text)
    text = re.sub(r"\.", " ", text)
    text = re.sub(r"!", " ! ", text)
    text = re.sub(r"\/", " ", text)
    text = re.sub(r"\^", " ^ ", text)
    text = re.sub(r"\+", " + ", text)
    text = re.sub(r"\-", " - ", text)
    text = re.sub(r"\=", " = ", text)
    text = re.sub(r"'", " ", text)
    # text = re.sub(r"?", "", text)
    text = re.sub(r"(\d+)(k)", r"\g<1>000", text)
    text = re.sub(r":", " : ", text)
    text = re.sub(r" e g ", " eg ", text)
    text = re.sub(r" b g ", " bg ", text)
    text = re.sub(r" u s ", " american ", text)
    # Decades such as "1980s" -> "1980".  The previous pattern r"\0s" is an
    # octal escape (NUL followed by "s"), and NUL is already stripped above,
    # so that rule could never match.
    text = re.sub(r"(\d)0s\b", r"\g<1>0", text)
    text = re.sub(r" 9 11 ", "911", text)
    text = re.sub(r"e - mail", "email", text)
    text = re.sub(r"j k", "jk", text)
    text = re.sub(r"\s{2,}", " ", text)

    text = text.split()

    if len(text) < max_len:
        text += ["STOP_PAD"] * (max_len - len(text))

    return text


def random_sequential_idxs(range_max):
    """Return a random permutation of ``0 .. range_max - 1`` as a list.

    Uses the global ``random`` module state, so results are reproducible
    after ``random.seed``.
    """
    permutation = [*range(range_max)]
    random.shuffle(permutation)
    return permutation


def create_embedding_layer(weights_matrix, non_trainable=False):
    """Build an ``nn.Embedding`` initialised from a 2-D weight matrix.

    Args:
        weights_matrix: array-like with a 2-tuple ``.shape`` of
            (num_embeddings, embedding_dim); copied as float32.
        non_trainable: when truthy, the embedding weight is frozen
            (``requires_grad`` is False).

    Returns:
        ``(embedding_layer, num_embeddings, embedding_dim)``.
    """
    rows, cols = weights_matrix.shape
    layer = nn.Embedding(rows, cols)
    pretrained = torch.tensor(weights_matrix, dtype=torch.float)
    layer.load_state_dict({'weight': pretrained})
    layer.weight.requires_grad = not non_trainable
    return layer, rows, cols


def load_train_data(train_path, max_len):
    """Read the question-pairs CSV and return tokenised examples.

    Args:
        train_path: path to a CSV file with a header row followed by rows of
            pair_id, q1_id, q2_id, q1_text, q2_text, is_duplicate.
        max_len: maximum number of tokens per question.  Shorter questions
            are padded by ``normalize_text``; a pair is skipped if either
            question is longer.

    Returns:
        list of ``(q1_tokens, q2_tokens, label)`` tuples, ``label`` an int.
        At most ``MAX_DATASET_SIZE`` data rows are read.
    """
    result = []
    # newline='' lets the csv module correctly handle line breaks embedded in
    # quoted question text.
    with open(train_path, mode='r', encoding="utf8", newline='') as csvfile:
        train_datareader = csv.reader(csvfile, delimiter=',')
        next(train_datareader, None)  # skip the header row
        for count, row in enumerate(train_datareader):
            # `>=` so exactly MAX_DATASET_SIZE rows are read (the old `>`
            # let one extra row through).
            if count >= MAX_DATASET_SIZE:
                break
            # Row format
            # pair_id, q1_id, q2_id, q1_text, q2_text, is_duplicate
            q1, q2 = normalize_text(row[3], max_len), normalize_text(row[4], max_len)
            if len(q1) <= max_len and len(q2) <= max_len:
                result.append((q1, q2, int(row[5])))

    return result


def vectorize(seq):
    """Map a sequence of tokens to a NumPy array of vocabulary indices.

    Tokens not yet in the module-level ``dictionary`` are added with the
    next free index taken from the global counter ``curr_idx``, which is
    then advanced.  Both globals are mutated.
    """
    global curr_idx
    indices = []
    for token in seq:
        known = dictionary.get(token)
        if known is None:
            known = curr_idx
            dictionary[token] = known
            curr_idx += 1
        indices.append(known)
    return np.array(indices)

def repackage_hidden(h):
    """Detach hidden state(s) from the autograd graph that produced them.

    ``h`` is either a single tensor or an iterable (possibly nested) of
    tensors, such as an LSTM's ``(h, c)`` pair.  The same structure is
    returned with each tensor replaced by ``tensor.detach()``; every
    container level comes back as a tuple.
    """
    if not isinstance(h, torch.Tensor):
        return tuple(map(repackage_hidden, h))
    return h.detach()

######################################################################
# Create the model:

class SiameseLSTM(nn.Module):
    """Siamese LSTM classifier for sentence pairs.

    Both sentences go through the same frozen embedding layer and the same
    LSTM (shared weights).  For each sentence, the LSTM outputs at every time
    step plus the top layer's final hidden and cell states are concatenated.
    The two sentence encodings are then fed to a small MLP that emits 2 scores
    per pair.

    ``forward`` returns unnormalised scores (logits), which is what
    ``nn.CrossEntropyLoss`` expects because it applies log-softmax itself.
    Use ``torch.softmax(output, dim=1)`` if probabilities are needed.
    """

    def __init__(self, embedding_dim, hidden_dim, batch_size, num_lstm_layers, embedding_weights=None):
        """
        Args:
            embedding_dim: LSTM input size; should equal the width of the
                embedding matrix.
            hidden_dim: LSTM hidden size.
            batch_size: default batch size used by ``init_hidden``.
            num_lstm_layers: number of stacked LSTM layers.
            embedding_weights: optional embedding matrix of shape
                (vocab_size, embedding_dim).  Defaults to the module-level
                ``weights_matrix``.
        """
        super(SiameseLSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.batch_size = batch_size
        self.num_lstm_layers = num_lstm_layers

        if embedding_weights is None:
            embedding_weights = weights_matrix
        # Pretrained embeddings are frozen (non_trainable=True).
        self.word_embeddings, _, _ = create_embedding_layer(embedding_weights, True)

        # The LSTM takes word embeddings as inputs, and outputs hidden states
        # with dimensionality hidden_dim.
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=self.num_lstm_layers,
            batch_first=True
        )

        # Final (h, c) of the most recent forward pass for each branch; kept
        # as attributes so existing code that reads or resets them still works.
        self.hidden1 = self.init_hidden()
        self.hidden2 = self.init_hidden()

        # Input: per-step outputs of both sentences (2 * MAX_SENTENCE_LENGTH * H)
        # plus the final h and c of both sentences (4 * H).
        self.fc1 = nn.Sequential(
            nn.Linear(4 * self.hidden_dim + 2 * (MAX_SENTENCE_LENGTH * self.hidden_dim), 3 * self.hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(3 * self.hidden_dim, self.hidden_dim),
            # Without a non-linearity here the two Linear layers would
            # collapse into a single linear map.
            nn.ReLU(inplace=True),
            nn.Linear(self.hidden_dim, 2),
        )

    def init_hidden(self, batch_size=None):
        """Return zero ``(h_0, c_0)`` tensors of shape (num_layers, batch, hidden_dim).

        ``batch_size`` defaults to the value given to the constructor.  The
        tensors are created on the same device as the embedding weights.
        """
        if batch_size is None:
            batch_size = self.batch_size
        device = self.word_embeddings.weight.device
        shape = (self.num_lstm_layers, batch_size, self.hidden_dim)
        return (torch.zeros(shape, device=device),
                torch.zeros(shape, device=device))

    def forward(self, sentence1, sentence2):
        """Score a batch of sentence pairs.

        Args:
            sentence1, sentence2: token-index tensors of shape
                (batch, MAX_SENTENCE_LENGTH); they are cast to long.

        Returns:
            Tensor of shape (batch, 2) holding unnormalised class scores.
        """
        batch = sentence1.size(0)
        embeds1 = self.word_embeddings(sentence1.long())
        embeds2 = self.word_embeddings(sentence2.long())

        # Pairs in different batches are unrelated, so each pass starts from
        # a fresh zero state rather than the previous batch's final state.
        # Reusing that state made predictions depend on batch order.
        lstm1_out, self.hidden1 = self.lstm(embeds1, self.init_hidden(batch))
        lstm2_out, self.hidden2 = self.lstm(embeds2, self.init_hidden(batch))

        # h_n / c_n have shape (num_layers, batch, H); use the top layer.
        concat_all = torch.cat((
            lstm1_out.reshape(batch, -1),
            self.hidden1[0][-1],
            self.hidden1[1][-1],
            lstm2_out.reshape(batch, -1),
            self.hidden2[0][-1],
            self.hidden2[1][-1],
        ), dim=1)
        return self.fc1(concat_all)

######################################################################
# Train the model:

print("## LOADING TRAINING DATA ##")

res = load_train_data(TRAIN_DATA_PATH, MAX_SENTENCE_LENGTH)

# Each question pair is used twice, once in each order, with the same label.
# Row i of training_data is the left sentence of example i and row i of
# training_data_flipped is its right sentence.
training_data = []
training_data_flipped = []

training_labels = []

total_num_examples = 0
print("## PROCESSING TRAINING DATA ##")

for r in res:
    r0_ndarray = vectorize(r[0])
    r1_ndarray = vectorize(r[1])

    training_data += [r0_ndarray, r1_ndarray]
    training_data_flipped += [r1_ndarray, r0_ndarray]

    total_num_examples += 2
    training_labels += [r[2], r[2]]

training_data = np.array(training_data)
training_data_flipped = np.array(training_data_flipped)
# vectorize() hands out indices 0 .. curr_idx - 1, so curr_idx is the
# vocabulary size.
matrix_len = curr_idx
weights_matrix = np.zeros((matrix_len, EMBEDDING_DIM))

for key, val in dictionary.items():
    # Ignore the stop word, this should be given an embedding vector of all 0s
    if val == 0:
        continue
    try:
        weights_matrix[val] = glove[key]
    except KeyError:
        # Word has no pretrained vector: initialise it randomly.
        weights_matrix[val] = np.random.normal(scale=0.6, size=(EMBEDDING_DIM, ))

print("## MAKING MODEL ##")

model = SiameseLSTM(EMBEDDING_DIM, HIDDEN_DIM, BATCH_SIZE, NUM_LAYERS)
# CrossEntropyLoss applies log-softmax internally, so the model's outputs
# should be unnormalised scores (logits).
loss_function = nn.CrossEntropyLoss()

# Created once, outside the epoch loop, so Adam's running moment estimates
# persist across epochs.
optimizer = optim.Adam(model.parameters(), lr=0.002)
# optimizer = QHM(model.parameters(), lr=1.0, nu=0.7, momentum=0.999)
# optimizer = QHAdam(model.parameters(), lr=0.02, nus=(0.7, 1.0), betas=(0.995, 0.999))

NUM_EPOCHS = 100

for epoch in range(NUM_EPOCHS):
    print("Epoch", epoch)

    iteration_idxs = random_sequential_idxs(total_num_examples)

    model.hidden1 = model.init_hidden()
    model.hidden2 = model.init_hidden()

    total_loss = 0
    steps = 0

    # Only full batches are used.  batch_start is a dedicated cursor; the
    # global curr_idx is vectorize()'s vocabulary counter and must not be
    # clobbered.
    for batch_start in range(0, total_num_examples - BATCH_SIZE + 1, BATCH_SIZE):
        next_batch = iteration_idxs[batch_start:batch_start + BATCH_SIZE]

        # PyTorch accumulates gradients, so clear them before each batch.
        optimizer.zero_grad()
        # Detach the stored LSTM states from the previous batch's graph so
        # gradients do not flow back into earlier batches.
        model.hidden1 = repackage_hidden(model.hidden1)
        model.hidden2 = repackage_hidden(model.hidden2)

        # Token indices stay integral; no float round-trip is needed.
        pair_left_elem = torch.tensor(training_data[next_batch], dtype=torch.long)
        pair_right_elem = torch.tensor(training_data_flipped[next_batch], dtype=torch.long)

        similarity = model(pair_left_elem, pair_right_elem)

        target_vals = torch.tensor([training_labels[idx] for idx in next_batch], dtype=torch.long)
        loss = loss_function(similarity, target_vals)

        if steps % 32 == 0:
            print('Percentage complete: {}%'.format(float(100*steps*BATCH_SIZE)/total_num_examples))

        # .item() so the graph attached to `loss` is not kept alive.
        total_loss += loss.item()
        steps += 1

        loss.backward()
        optimizer.step()

    # max(steps, 1) avoids ZeroDivisionError when there are fewer examples
    # than BATCH_SIZE.
    print("Avg. loss for Epoch {} is {}".format(epoch, float(total_loss) / max(steps, 1)))

# # See what the scores are after training
# with torch.no_grad():
#     inputs = prepare_sequence(training_data[0][0], word_to_ix)
#     tag_scores = model(inputs)
#
#     # The sentence is "the dog ate the apple".  i,j corresponds to score for tag j
#     # for word i. The predicted tag is the maximum scoring tag.
#     # Here, we can see the predicted sequence below is 0 1 2 0 1
#     # since 0 is index of the maximum value of row 1,
#     # 1 is the index of maximum value of row 2, etc.
#     # Which is DET NOUN VERB DET NOUN, the correct sequence!
#     print(tag_scores)

#2

There might be some small issues in your code:

  • Why are you using retain_graph=True in loss.backward()? As your code is quite long, I couldn’t check it properly, but this might be a source of growing memory. Do you get an error if you don’t specify it?
  • Currently you are re-initializing the optimizer in each epoch. All of Adam's internal running estimates will be lost, which will result in poor training.
  • Variables have been deprecated since PyTorch 0.4.0. If you are using a newer version, you can just use tensors instead.
  • The loss tensor will be accumulated in total_loss, which will increase the memory usage, as the computation graph is attached to each loss tensor. If you just need total_loss for printing/debugging purposes, you should store it as total_loss += loss.item()

I think repackage_hidden should work fine.

Let me know if these points help.


(Michael Li) #3

Thanks so much for the comments! I did get an error without retain_graph=True; I can try removing it and checking again.

Good catch on the optimizer being reinitialized every epoch.

I changed the code to not use Variables any more, thanks for that.

I realized the loss tensor itself was being accumulated, which was a dumb mistake. I have corrected that and will run the code again to see how it performs.