Text classification unexpectedly slow

I wrote a simple demo for short text classification, but it runs unexpectedly slowly. When I tried to find the bottleneck, I couldn't pin it down.

At first, the bottleneck appeared to be this line:

running_loss += loss.data[0]

However, after commenting out the above line, the slowdown moved to these lines in the get_batch() function:

        data = data.cuda()
        target = target.cuda()

Is there any problem with the code? I ran this script on a GPU (Titan X) with CUDA 8.0, Python 2.7, and Ubuntu 16; PyTorch was installed via pip. The data is randomly generated.

The code is attached below:

import numpy as np
import time
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
class Net(nn.Module):
    def __init__(self, vocab_size, emb_dim):
        """
        :param vocab_size:
                an int value, representing the total number of vocabs in the pre-defined lookup table
        :param emb_dim:
                an int value, representing the dimension of each word vector
        """
        super(Net, self).__init__()
        self.lookup_table = nn.Embedding(vocab_size, emb_dim)
        self.init_embedding()
        self.encoder = nn.Conv2d(in_channels=1,
                                 out_channels=200,
                                 kernel_size=(3, 300))
        self.hid_1 = nn.Linear(200, 200)
        self.hid_2 = nn.Linear(200, 10)
    def forward(self, x, training=False):
        """
        :param x:
                a LongTensor of word indices with size [N, H]; after the embedding
                lookup and unsqueeze it becomes [N, C, H, W], where
                N: batch size
                C: number of channels, 1 in the text case
                H: height, in the text case the length of the text
                W: width, in the text case the dimension of the embedding
        :param training:
                boolean value, whether the forward is for training purpose
        :return:
                a tensor [N, L], where L is the number of classes
        """
        x = self.lookup_table(x)
        x = x.unsqueeze(1)
        enc = F.relu(self.encoder(x))
        enc = F.max_pool2d(enc, kernel_size=(48, 1))
        enc = enc.squeeze()
        enc = F.dropout(F.relu(self.hid_1(enc)), training=training)
        enc = F.relu(self.hid_2(enc))
        pred_prob = F.softmax(enc)
        return pred_prob
    def init_embedding(self):
        initrange = 0.1
        self.lookup_table.weight.data.uniform_(-initrange, initrange)
def get_batch(source, batch_size, i):
    data = source[0][batch_size * i: batch_size * (i + 1)]
    target = source[1][batch_size * i: batch_size * (i + 1)]
    if torch.cuda.is_available():
        print "moving data..."
        st = time.time()
        data = data.cuda()
        target = target.cuda()
        dt = time.time() - st
        print "moving data time: {}".format(dt)
    return data, target
print "Setting seed..."
seed = 1234
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
train_batch_size = 100
test_batch_size = 100
rng = np.random.RandomState(seed)
num_train_instances = 20000
num_test_instances = 2000
max_text_len = 50
vocab_size = 62639
num_classes = 10
emb_dim = 300
print "Generating random data..."
train_set_numpy = rng.randint(0, vocab_size, (num_train_instances, max_text_len)), rng.randint(0, num_classes, (num_train_instances,))
test_set_numpy = rng.randint(0, vocab_size, (num_test_instances, max_text_len)), rng.randint(0, num_classes, (num_test_instances,))
print "Converting numpy data into Tensor..."
train_set = torch.from_numpy(train_set_numpy[0]), torch.from_numpy(train_set_numpy[1])
test_set = torch.from_numpy(test_set_numpy[0]), torch.from_numpy(test_set_numpy[1])
n_train_batch = train_set[0].size()[0] / train_batch_size + 1
n_test_batch = test_set[0].size()[0] / test_batch_size + 1
model = Net(vocab_size=vocab_size, emb_dim=emb_dim)
if torch.cuda.is_available():
    print "move model to GPU"
    model.cuda()
    print "move done"
print model
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
for epoch in xrange(10):
    running_loss = 0.
    for i in xrange(n_train_batch):
        start_time = time.time()
        print "batch: %d" % i
        text, labels = get_batch(train_set, train_batch_size, i)
        text, labels = Variable(text), Variable(labels)
        print "zero optimizer"
        optimizer.zero_grad()
        print "compute forward"
        st = time.time()
        outputs = model(text, training=True)
        dt = time.time() - st
        print "compute forward time: {}".format(dt)
        print "compute loss"
        st = time.time()
        loss = criterion(outputs, labels)
        dt = time.time() - st
        print "compute loss time: {}".format(dt)
        print "compute backword"
        st = time.time()
        loss.backward()
        dt = time.time() - st
        print "compute backword time: {}".format(dt)
        print "update gradient"
        st = time.time()
        optimizer.step()
        dt = time.time() - st
        print "update gradient time: {}".format(dt)
        print "accumulate loss"
        st = time.time()
        running_loss += loss.data[0]
        dt = time.time() - st
        print "accumulate loss time: {}".format(dt)
        duration = time.time() - start_time
        if i % 1 == 0:
            print "training speed: {}/sec".format(train_batch_size / duration)
            running_loss = 0.
        if i % 4 == 3:
            start_time = time.time()
            correct = 0.
            total = 0.
            for j in xrange(n_test_batch):
                text, labels = get_batch(test_set, test_batch_size, j)
                text, labels = Variable(text), Variable(labels)
                outputs = model(text)
                _, predicted = torch.max(outputs.data, dim=1)
                total += labels.size()[0]
                correct += (predicted == labels.data).sum()
            duration = time.time() - start_time
            print "acc: {}".format(100 * correct / total)
            print "testing speed: {}/sec".format(total / duration)
        print "loop done"

The points where your benchmarking indicates slowdowns are CUDA synchronization points. Any host-GPU scalar copy (e.g. accessing individual elements like loss.data[0]) and any transfer of a tensor that isn't in pinned memory (e.g. .cuda()) forces the CPU and GPU to synchronize. Your CPU races through the model definition and then stops at one of these points, waiting for the GPU to finish computing the forward pass.

If you want to reliably benchmark the compute time of your model, do this:

torch.cuda.synchronize()   # wait for all pending GPU work to finish
start = time.time()
output = model(input)
torch.cuda.synchronize()   # wait for the forward pass to actually complete
end = time.time()

Can you explain what “unexpectedly slow” means? What is the runtime and what did you expect? It seems that you're using a convolutional kernel size of 3x300, and that will be incredibly costly to compute (especially with 200 output channels).


PS: you don't need to pass a training argument to forward. You can use self.training and call model.eval() and model.train() to toggle the flag.
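For example (a minimal sketch against the Net above; only the dropout line and the calling code change):

# inside Net.forward, replace the custom `training` argument with the module flag:
enc = F.dropout(F.relu(self.hid_1(enc)), training=self.training)

# in the training loop:
model.train()    # sets self.training = True, so dropout is applied
outputs = model(text)

# in the test loop:
model.eval()     # sets self.training = False, so dropout is disabled
outputs = model(text)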

Thanks for the reply!
Now I know why the timing is weird.

For the slow part, I timed the training on one batch and printed it in the code. It shows that the training speed is around 13 instances per second. With a larger graph in TensorFlow, the same machine can easily reach thousands of instances per second.

In the PyTorch code above, every batch takes about 8 seconds.

Thanks.

Try using more batches; the first one will always be slower, because a lot of time is spent allocating memory that will be cached for subsequent runs. Are the batch sizes and model parameters exactly the same as in TensorFlow? I can try to run your model tomorrow, but it would be helpful if you could also provide the TensorFlow implementation.

The 8 seconds is not for the first batch; it is the average.

The batch size for TensorFlow is 1024. There are four parallel convolutional layers in my TensorFlow model, with kernel sizes 1x250, 3x250, 4x250, and 5x250, followed by some hidden layers. That is a much larger graph than the example I showed above.

Sorry, the TensorFlow code has too many dependencies; I don't think I can simply show it here.

Thank you very much!

And what is the time per batch or element that you can achieve in TensorFlow?

I wrote a TensorFlow demo with the same graph as the PyTorch version, run on the same machine with TensorFlow 0.12.

the timing info is:
batch duration: 0.0545980930328 s
training speed: 1831.56580102/sec

here is the code:

import time
import tensorflow as tf
import numpy as np
with tf.Graph().as_default():
    max_text_len = 50
    num_classes = 10
    vocab_size = 62639
    emb_dim = 300
    num_train_instances = 20000
    num_test_instances = 2000
    train_batch_size = 100
    test_batch_size = 100
    # model part
    filter_size = 3
    num_filters = 200
    # generating random data
    seed = 1234
    rng = np.random.RandomState(seed)
    train_set = rng.randint(0, vocab_size, (num_train_instances, max_text_len)), rng.randint(0, num_classes, (num_train_instances))
    test_set = rng.randint(0, vocab_size, (num_test_instances, max_text_len)), rng.randint(0, num_classes, (num_test_instances))
    n_train_batch = num_train_instances / train_batch_size + 1
    n_test_batch = num_test_instances / test_batch_size + 1
    def get_batch(source, i, batch_size):
        data = source[0][i * batch_size: (i + 1) * batch_size]
        target = source[1][i * batch_size: (i + 1) * batch_size]
        return data, target
    #  describe graph
    input_x = tf.placeholder(tf.int32, [None, max_text_len], name="input_x")
    input_y = tf.placeholder(tf.int32, [None], name="input_y")
    with tf.device('/cpu:0'), tf.name_scope("embedding"):
        emb_w = tf.Variable(tf.random_uniform([vocab_size, emb_dim], -1.0, 1.0), name="emb_w")
        embedded = tf.nn.embedding_lookup(emb_w, input_x)
        embedded_expanded = tf.expand_dims(embedded, -1)
    with tf.name_scope("conv-maxpool-%s" % filter_size):
        filter_shape = [filter_size, emb_dim, 1, num_filters]
        W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="W")
        b = tf.Variable(tf.constant(0.1, shape=[num_filters]), name="b")
        conv = tf.nn.conv2d(
            embedded_expanded,
            W,
            strides=[1, 1, 1, 1],
            padding="VALID",
            name="conv")
        # Apply nonlinearity
        h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")
        # Maxpooling over the outputs
        pooled = tf.nn.max_pool(
            h,
            ksize=[1, max_text_len - filter_size + 1, 1, 1],
            strides=[1, 1, 1, 1],
            padding='VALID',
            name="pool")
        h_pool_flat = tf.reshape(pooled, [-1, num_filters])
    with tf.name_scope("dropout"):
        h_drop = tf.nn.dropout(h_pool_flat, 0.5)
    with tf.name_scope("hidden"):
        W = tf.get_variable(
            "h_W",
            shape=[num_filters, num_filters],
            initializer=tf.contrib.layers.xavier_initializer())
        b = tf.Variable(tf.constant(0.1, shape=[num_filters]), name="b")
        hidden_out = tf.nn.xw_plus_b(h_drop, W, b, name="hidden_out")
    with tf.name_scope("output"):
        W = tf.get_variable(
            "o_W",
            shape=[num_filters, num_classes],
            initializer=tf.contrib.layers.xavier_initializer())
        b = tf.Variable(tf.constant(0.1, shape=[num_classes]), name="b")
        scores = tf.nn.xw_plus_b(hidden_out, W, b, name="scores")
    with tf.name_scope("loss"):
        losses = tf.nn.sparse_softmax_cross_entropy_with_logits(scores, input_y)
# training
    session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    sess = tf.Session(config=session_conf)
    optimizer = tf.train.AdamOptimizer(1e-3)
    grads_and_vars = optimizer.compute_gradients(losses)
    train_op = optimizer.apply_gradients(grads_and_vars)
    sess.run(tf.global_variables_initializer())
    for epoch in xrange(10):
        for i in xrange(n_train_batch):
            start_time = time.time()
            x_batch, y_batch = get_batch(train_set, i, train_batch_size)
            print x_batch.shape
            print y_batch.shape
            feed_dict = {
                input_x: x_batch,
                input_y: y_batch,
            }
            _, loss_ = sess.run([train_op, losses], feed_dict)
            duration = time.time() - start_time
            print "duration: {} s".format(duration)
            print "speed: {}/sec".format(train_batch_size / duration)

Thank you very much!

It seems there are some problems or pitfalls with Conv2d on the GPU. When I used in_channels = 1, out_channels = 200, kernel height = 3, kernel width = 300 (the original setting), the backward pass was very slow. When I changed it to in_channels = 300, out_channels = 200, kernel height = 1, kernel width = 3, the model works well (I got 9500/sec on a GTX 1080).

This seems like a GPU-specific problem, because no problem occurred when running on the CPU.
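For reference, a minimal sketch of what that reshaped layout might look like when applied to the Net above (the class name and the transpose/unsqueeze details are illustrative, not taken from the thread; it assumes the imports from the original script):

class NetReshaped(nn.Module):
    """Variant of Net that gives cuDNN a (1, 3) kernel instead of (3, 300)."""
    def __init__(self, vocab_size, emb_dim):
        super(NetReshaped, self).__init__()
        self.lookup_table = nn.Embedding(vocab_size, emb_dim)
        # treat the embedding dimension as the channel dimension
        self.encoder = nn.Conv2d(in_channels=emb_dim,
                                 out_channels=200,
                                 kernel_size=(1, 3))
        self.hid_1 = nn.Linear(200, 200)
        self.hid_2 = nn.Linear(200, 10)
    def forward(self, x, training=False):
        x = self.lookup_table(x)                         # [N, L, emb_dim]
        x = x.transpose(1, 2).contiguous().unsqueeze(2)  # [N, emb_dim, 1, L]
        enc = F.relu(self.encoder(x))                    # [N, 200, 1, L - 2]
        enc = F.max_pool2d(enc, kernel_size=(1, enc.size(3)))
        enc = enc.squeeze()
        enc = F.dropout(F.relu(self.hid_1(enc)), training=training)
        enc = F.relu(self.hid_2(enc))
        return F.softmax(enc)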

Hmm, yeah, that will likely make cuDNN pick a different algorithm. I don't think it's optimized for 3x300, so it probably falls back to a slow GEMM implementation, while 1x3 is quite commonly used and will likely use a smarter algorithm.

If it is a GPU-specific problem, then why does the TensorFlow code work well with the same kernel size? Thanks

So is there no way to speed it up other than reshaping the input? Thanks

It's not a GPU-specific problem; we're probably just using different cuDNN calls to select algorithms, and for some reason they fail to pick the fast ones for this use case. I don't know whether there's another way; it will take me some time to properly benchmark it and see what the issue is. If you want to use it now, just reshape the input.


Thank you very much!

Ok, I’ve confirmed the issue. The fix for now is to reshape the inputs as @kim.seonghyeon said, or to set torch.backends.cudnn.enabled = False. Still, the first approach is likely to be faster for you.

With cuDNN disabled, training runs at 1k samples/sec for me. With it enabled, it slows down to 12.

Ok, @colesbury made a good point. The fix is to add

torch.backends.cudnn.benchmark = True

This flag controls cuDNN benchmark mode and should be used only if every forward pass uses the same data dimensions. After setting it, your code runs at 7k/s for training and 64k/s for testing on my machine.
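For example, it is enough to set it once near the top of the script, before the training loop (a minimal sketch against the code above):

# cuDNN will time several convolution algorithms on the first batch and cache
# the fastest one; only worthwhile when the input shapes never change.
torch.backends.cudnn.benchmark = True

model = Net(vocab_size=vocab_size, emb_dim=emb_dim)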

Also, a small improvement at test time can be achieved by using volatile=True on the input to the network.
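For example, in the test loop (a sketch; volatile inputs build no autograd graph, which saves memory and time during inference):

text, labels = get_batch(test_set, test_batch_size, j)
text = Variable(text, volatile=True)      # inference only: no graph is built
labels = Variable(labels, volatile=True)
outputs = model(text)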

Thanks for the follow-up. I will stick with @kim.seonghyeon's method, with which my code gets 8K/s for training and 80K/s for testing.

Thanks!

Try adding benchmark anyway. It will probably make it slightly faster.

If I set this flag:

What do you mean by the same data dimensions?

Do I have to use a fixed batch size?
If so, what happens when the data size is not an integer multiple of the batch size?

Thanks.