How to improve sampling performance?


#1

Hi,

I am trying to sample from an autoencoder model (Char RNN) in a small web app. The sampling works perfectly fine; the problem is performance. It is too slow to use beyond a fairly limited sample length: on the container I'm serving from, it can take tens of seconds to sample a thousand characters. This has to happen in real time, so I'm wondering whether I'm missing a trick.

Not surprisingly the majority of time is spent iterating over the number of requested samples, for which the code looks like this:

    import torch

    with torch.no_grad():  # inference only, so skip autograd bookkeeping
        for p in range(predict_len):
            output, hidden = model(inp, hidden)

            # Sample from the network as a multinomial distribution
            output_dist = output.view(-1).div(temperature).exp()
            top_i = torch.multinomial(output_dist, 1).item()  # plain Python int

            # Add predicted character to string and use as next input
            predicted_char = index_to_char[top_i]
            predicted += predicted_char
            inp = index_to_tensor(char_to_index[predicted_char])

And specifically this line:

    output, hidden = model(inp, hidden)

Is there any way to speed this up?

Notes:

  1. I see a significant (although still not ideal) speed-up when using GPU as opposed to CPU, however GPU is not available to me for serving the web app.
  2. The model checkpoint is 230 MB on disk, which is obviously not small, so if there is a way to reduce this after training, that would be an option as well (I'm not sure how, though).
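One quick check on the checkpoint size is to compare the file on disk with the size of the weights alone. The sketch below is only an illustration: `parameter_megabytes` is a hypothetical helper, and the `nn.LSTM` stand-in (3 layers, 1024 hidden units, float32) is an assumption, since the thread does not say which cell type the model uses.

```python
import torch.nn as nn


def parameter_megabytes(model: nn.Module) -> float:
    """Size of the model's parameters and buffers in megabytes (10**6 bytes)."""
    tensors = list(model.parameters()) + list(model.buffers())
    return sum(t.numel() * t.element_size() for t in tensors) / 1e6


# Hypothetical stand-in for the recurrent core of the served model
# (3 layers, 1024 hidden units; the LSTM cell type is an assumption).
core = nn.LSTM(input_size=1024, hidden_size=1024, num_layers=3)
print(f"weights only: {parameter_megabytes(core):.1f} MB")
```

If the number for the real model comes out well below the file size, the checkpoint may contain more than the weights (optimizer state, for example), and saving only `model.state_dict()` would shrink the file. That would reduce load time and memory, but it would not by itself make the forward pass faster.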

More broadly, how do people serve models in realtime when sampling tends to be this slow? Or is that par for the course really and I’m not going to be able to optimise it much?

Thanks in advance for any advice!


(Alban D) #2

Hi,

The problem is that if all of the time is spent in the forward pass of your network, there isn’t much you can do without modifying this network or increasing your compute capabilities.

  • Using a lighter network (in terms of compute, not memory) would be a solution: if you halve the depth, the runtime will be roughly halved as well. But depending on your application, that might not be feasible.
  • Use a GPU instance as the backend to speed up the forward pass, or use multi-core CPU instances; that might help as well.
  • If a single request can't make proper use of all the CPU cores you have, you could batch requests together (so that each forward pass serves a whole batch).
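The batching idea in the last bullet can be sketched as follows. This is a minimal illustration, not the model from the thread: the `CharRNN` class, its layer layout, and the `sample_batch` helper are all assumptions made for the example. The point is that one call to the model advances every pending request by one character.

```python
import torch
import torch.nn as nn


class CharRNN(nn.Module):
    """Hypothetical stand-in for the thread's model (layout is an assumption)."""

    def __init__(self, vocab_size, hidden_size, num_layers):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.LSTM(hidden_size, hidden_size, num_layers)
        self.decode = nn.Linear(hidden_size, vocab_size)

    def forward(self, inp, hidden=None):
        # inp: (batch,) character indices -> (1, batch, hidden), a sequence of length 1
        out, hidden = self.rnn(self.embed(inp).unsqueeze(0), hidden)
        return self.decode(out.squeeze(0)), hidden  # logits: (batch, vocab)


@torch.no_grad()
def sample_batch(model, start_indices, predict_len, temperature=0.8):
    """Generate predict_len characters for every request in a single loop."""
    model.eval()
    inp, hidden, steps = start_indices, None, []
    for _ in range(predict_len):
        logits, hidden = model(inp, hidden)
        probs = torch.softmax(logits / temperature, dim=-1)
        inp = torch.multinomial(probs, 1).squeeze(1)  # one sample per request
        steps.append(inp)
    return torch.stack(steps, dim=1)  # (batch, predict_len)


# Four requests, each starting from a different character index.
model = CharRNN(vocab_size=100, hidden_size=64, num_layers=3)
samples = sample_batch(model, torch.tensor([1, 2, 3, 4]), predict_len=20)
```

Batching improves throughput across concurrent requests rather than the latency of a single one, so it only helps if the web app actually has several requests in flight at the same time.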

As a more general idea, you should be able to work out from your model roughly how many operations a single forward pass requires. That will tell you how much hardware you need to run it at the speed you want.
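A back-of-the-envelope version of that estimate might look like the sketch below. It counts only the matrix multiplications, and it assumes an LSTM-style cell (`gates=4`; use `gates=3` for a GRU), an input embedding as wide as the hidden state, and a vocabulary of 100 characters. None of these are stated in the thread, and `macs_per_char` is a hypothetical helper.

```python
def macs_per_char(hidden_size, num_layers, vocab_size, gates=4):
    """Approximate multiply-accumulates needed to generate one character.

    gates=4 corresponds to an LSTM cell and gates=3 to a GRU; the cell type
    is an assumption, as is an input embedding of width hidden_size.
    """
    # Each layer multiplies [input, previous hidden] by one weight block per gate.
    recurrent = num_layers * gates * hidden_size * (hidden_size + hidden_size)
    decoder = hidden_size * vocab_size  # final projection onto the characters
    return recurrent + decoder


for layers, hidden in [(3, 1024), (2, 1024), (2, 512)]:
    per_char = macs_per_char(hidden, layers, vocab_size=100)
    gmacs = per_char * 1000 / 1e9
    print(f"{layers} layers x {hidden} units: {gmacs:.1f} GMACs per 1000 chars")
```

Under these assumptions the cost grows linearly with the number of layers and roughly quadratically with the hidden size. Dividing the total by the throughput the serving CPU can actually sustain gives a rough lower bound on generation time.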


#3

Hi albanD,

Thanks for your suggestions. With regard to the forward passes, are you referring to the training process?

In my case the training process has already been run on a separate, more powerful server with GPUs. So the checkpoint file is ready to use. Unfortunately the web app runs on a tiny web container instance with no GPUs available. I load the checkpoint into memory at startup, and it’s during the sampling (generate) process that I’m encountering performance issues.

The training was done on 3 layers, this proved more or less optimal - going down to 2 means the training loss is not as good. Do you mean that this can make a difference during the generate process as well?


(Alban D) #4

No, when I say forward pass, I mean evaluation time, when you just perform a forward pass through your net (what you call sampling/generate).
If your network has 2 layers instead of 3, then yes, the generate process will be faster, because the line `output, hidden = model(inp, hidden)` will run faster.


#5

Hi albanD,

Thanks for clarifying.
Having tried that on some older checkpoints with 2 hidden layers it is indeed faster as you say. The 3 layer version remains the best though (1024 network nodes).

Sampling from trained networks with 2 layers and varying numbers of network nodes (512, 1024) suggests that increasing the number of nodes while reducing the number of layers still incurs a performance cost. I guess performance is the price I have to pay for the size and depth of the network, which I need in this case…

I’ll look into the idea of batch forward passes as you suggest and see if I can make progress with that.