NCE loss for large output vocabularies

I added NCE loss to the word_language_model example in this fork. Convergence is slower (measured in epochs) than with CE, but that may come down to tuning learning rates, gradient clipping, etc. My question, however, is about processing speed. The output-layer speedup should be roughly |V| / (K+1), where |V| is the vocabulary size and K is the number of noise samples per target word (plus some overhead), e.g. about 100x for the 10K-word PTB vocabulary with K = 100. Yet my NCE implementation is much slower than CE.

I am new to PyTorch, and I am hoping more experienced users can point out any obvious fixes to speed up the computations. There are a bunch of index_select() and repeat() operations in linear_nce.py that are probably unnecessary, but I could not find suitable sparse-matrix or broadcasting alternatives in the docs. Below are the first couple of epochs of the NCE and CE runs. Thanks for any suggestions.
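For context, the point of the NCE output layer is that only the weight rows for the target word and its K noise samples need to be touched, instead of the full |V| x hidden matrix. A minimal sketch of that idea (not the actual linear_nce.py code; the class and argument names below are made up for illustration):

import torch
import torch.nn as nn

class NCEOutput(nn.Module):
    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(vocab_size, hidden_size) * 0.01)
        self.bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, hidden, candidate_ids):
        # hidden: (batch, hidden_size); candidate_ids: (batch, K+1) with the
        # target word in column 0 and the K noise words in the remaining columns.
        w = self.weight[candidate_ids]                           # (batch, K+1, hidden_size) gather
        b = self.bias[candidate_ids]                             # (batch, K+1)
        scores = torch.bmm(w, hidden.unsqueeze(2)).squeeze(2) + b
        return scores                                            # unnormalized scores for K+1 candidates

The cost of this forward pass scales with K+1 rather than |V|, which is where the |V| / (K+1) speedup estimate comes from.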

NCE
python main.py --cuda --epochs 4 --loss nce --num_noise 100

| epoch 1 | 200/ 1327 batches | lr 20.00 | ms/batch 39.14 | nce_loss 5.79
| epoch 1 | 400/ 1327 batches | lr 20.00 | ms/batch 37.79 | nce_loss 5.40
| epoch 1 | 600/ 1327 batches | lr 20.00 | ms/batch 37.75 | nce_loss 5.17
| epoch 1 | 800/ 1327 batches | lr 20.00 | ms/batch 37.73 | nce_loss 5.06
| epoch 1 | 1000/ 1327 batches | lr 20.00 | ms/batch 36.82 | nce_loss 4.91
| epoch 1 | 1200/ 1327 batches | lr 20.00 | ms/batch 37.81 | nce_loss 4.81

| end of epoch 1 | time: 51.17s | valid loss 6.02 | valid ppl 410.86

| epoch 2 | 200/ 1327 batches | lr 20.00 | ms/batch 38.00 | nce_loss 4.69
| epoch 2 | 400/ 1327 batches | lr 20.00 | ms/batch 37.79 | nce_loss 4.61
| epoch 2 | 600/ 1327 batches | lr 20.00 | ms/batch 37.80 | nce_loss 4.55
| epoch 2 | 800/ 1327 batches | lr 20.00 | ms/batch 37.85 | nce_loss 4.50
| epoch 2 | 1000/ 1327 batches | lr 20.00 | ms/batch 37.77 | nce_loss 4.49
| epoch 2 | 1200/ 1327 batches | lr 20.00 | ms/batch 37.78 | nce_loss 4.40

| end of epoch 2 | time: 50.99s | valid loss 5.70 | valid ppl 299.02

CE
python main.py --cuda --epochs 2 --loss ce

| epoch 1 | 200/ 1327 batches | lr 20.00 | ms/batch 14.67 | loss 6.94 | ppl 1031.24
| epoch 1 | 400/ 1327 batches | lr 20.00 | ms/batch 12.48 | loss 6.31 | ppl 551.36
| epoch 1 | 600/ 1327 batches | lr 20.00 | ms/batch 12.48 | loss 6.04 | ppl 420.70
| epoch 1 | 800/ 1327 batches | lr 20.00 | ms/batch 12.50 | loss 5.78 | ppl 323.96
| epoch 1 | 1000/ 1327 batches | lr 20.00 | ms/batch 12.41 | loss 5.64 | ppl 281.56
| epoch 1 | 1200/ 1327 batches | lr 20.00 | ms/batch 12.38 | loss 5.49 | ppl 241.55

| end of epoch 1 | time: 17.80s | valid loss 5.40 | valid ppl 220.75

| epoch 2 | 200/ 1327 batches | lr 20.00 | ms/batch 12.45 | loss 5.40 | ppl 220.57
| epoch 2 | 400/ 1327 batches | lr 20.00 | ms/batch 12.44 | loss 5.33 | ppl 206.70
| epoch 2 | 600/ 1327 batches | lr 20.00 | ms/batch 12.43 | loss 5.28 | ppl 197.03
| epoch 2 | 800/ 1327 batches | lr 20.00 | ms/batch 12.41 | loss 5.17 | ppl 176.66
| epoch 2 | 1000/ 1327 batches | lr 20.00 | ms/batch 12.40 | loss 5.17 | ppl 175.13
| epoch 2 | 1200/ 1327 batches | lr 20.00 | ms/batch 12.45 | loss 5.06 | ppl 157.51

| end of epoch 2 | time: 17.34s | valid loss 5.10 | valid ppl 164.46

I have a similar implementation of NCE. After running nvprof on the training script, I found that the CUDA kernel sampleMultinomialWithReplacement (https://github.com/pytorch/pytorch/blob/0f65c9267d5ec55584b0ec65acb5374c95af9c16/torch/lib/THC/THCTensorRandom.cuh), i.e. torch.multinomial, accounts for about 75% of the training time.
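For reference, nvprof can wrap the training command directly; for instance, with the arguments from the original post (adjust for your own script):

nvprof python main.py --cuda --epochs 1 --loss nce --num_noise 100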

Since I am using the same noise samples for all targets in the minibatch, torch.multinomial costs less than 10% of the total (for 100 noise samples). About two-thirds of the time goes into the index_select() calls and the noise-probability calculations (pointwise multiplies). There must be more efficient sparse-matrix alternatives, but I haven't found them in the docs.
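For what it's worth, sharing the noise samples means torch.multinomial only runs once per minibatch. A minimal sketch of that pattern (variable names are illustrative, and the uniform noise_probs below stands in for the real unigram distribution):

import torch

vocab_size, batch_size, num_noise = 10000, 20, 100
noise_probs = torch.ones(vocab_size) / vocab_size   # stand-in for the unigram distribution

# One multinomial call per minibatch; every target row reuses the same K noise ids.
noise_ids = torch.multinomial(noise_probs, num_noise, replacement=True)   # (K,)
noise_ids = noise_ids.unsqueeze(0).expand(batch_size, num_noise)          # (batch, K), no data copy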

How do you profile your execution time? Most GPU operations in PyTorch are asynchronous and return right after being launched, so it makes sense to add torch.cuda.synchronize() before and after the critical lines when you are profiling.
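A small timing helper along those lines (just a sketch of the synchronize-before-and-after pattern):

import time
import torch

def cuda_timed(fn, *args):
    # Flush pending GPU work, time the call, then synchronize again so the
    # measurement covers the kernels themselves rather than just the launches.
    torch.cuda.synchronize()
    start = time.time()
    result = fn(*args)
    torch.cuda.synchronize()
    return result, time.time() - start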

Thanks for your suggestion. I was missing a torch.cuda.synchronize() before the call to torch.multinomial(), an oversight. I added the more efficient sampling scheme described here, fixed a few other inefficiencies in linear_nce(), and now NCE does run a bit faster than CE on the 10K-vocabulary PTB and much faster on larger vocabularies.

| epoch 1 | 200/ 1327 batches | lr 20.00 | ms/batch 11.86 | nce_loss 5.90
| epoch 1 | 400/ 1327 batches | lr 20.00 | ms/batch 10.10 | nce_loss 5.39
| epoch 1 | 600/ 1327 batches | lr 20.00 | ms/batch 10.09 | nce_loss 5.17
| epoch 1 | 800/ 1327 batches | lr 20.00 | ms/batch 10.09 | nce_loss 5.06
| epoch 1 | 1000/ 1327 batches | lr 20.00 | ms/batch 10.09 | nce_loss 4.94
| epoch 1 | 1200/ 1327 batches | lr 20.00 | ms/batch 10.12 | nce_loss 4.82

| end of epoch 1 | time: 14.59s | valid loss 6.02 | valid ppl 412.35

| epoch 2 | 200/ 1327 batches | lr 20.00 | ms/batch 10.19 | nce_loss 4.71
| epoch 2 | 400/ 1327 batches | lr 20.00 | ms/batch 10.13 | nce_loss 4.62
| epoch 2 | 600/ 1327 batches | lr 20.00 | ms/batch 10.11 | nce_loss 4.56
| epoch 2 | 800/ 1327 batches | lr 20.00 | ms/batch 10.12 | nce_loss 4.52
| epoch 2 | 1000/ 1327 batches | lr 20.00 | ms/batch 10.16 | nce_loss 4.50
| epoch 2 | 1200/ 1327 batches | lr 20.00 | ms/batch 10.12 | nce_loss 4.41

| end of epoch 2 | time: 14.30s | valid loss 5.69 | valid ppl 297.10

Hi,
I am using the implementation for NCE training found in https://github.com/parthaca/examples/tree/master/word_language_model

After a few (6-9) backprops I start getting NaNs; I suspect some very small numbers are being generated.
The overall loss for the different batches then comes out as NaN. Any idea what may have gone wrong?
Should I normalize some of these outputs to make it work?
Would appreciate your advice,
Shiran

My implementation of NCE loss did not use a stable log-sum-exp computation, which may have caused your NaN problem. I changed the implementation to use torch.nn.functional.binary_cross_entropy_with_logits(), which should be numerically stable.
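For anyone following along, an NCE loss built on that function has roughly the shape sketched below (a minimal illustration, not the exact code from the fork; the shapes and names are assumptions). Each candidate word is scored as a binary "data vs. noise" decision with logit = score - log(K * p_noise), and BCE-with-logits handles the exponentials stably.

import torch
import torch.nn.functional as F

def nce_loss(scores, noise_probs, num_noise):
    # scores:      (batch, K+1) unnormalized model scores, target word in column 0.
    # noise_probs: (batch, K+1) unigram noise probability of each candidate word.
    logits = scores - torch.log(num_noise * noise_probs)   # logit of P(data | word)
    labels = torch.zeros_like(scores)
    labels[:, 0] = 1.0                                      # target = data, noise samples = noise
    return F.binary_cross_entropy_with_logits(logits, labels)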

Thanks for improving your code. There were two small places I needed to fix because they didn't run cleanly on my machine. One is in linear_nce.py:
unigram_prob = torch.Tensor(unigram_prob)+1e-10
self.unigram_prob = Variable(unigram_prob, requires_grad=False).cuda()

and the second is in nce_loss.py:
loss = F.binary_cross_entropy_with_logits(logits, nce_target)

The problem this time was that it seemed to get stuck at some point and took forever to complete a round (I played with a toy example that previously produced the NaNs, and this time it never reaches the point of printing the ppl because it never finishes an epoch). I didn't debug further to investigate the problem, but I wonder whether it runs on your end?

I am not sure why those two changes are needed. In fact, it is necessary to set reduce=False. The default PTB example runs fine for me (only slightly faster than CE because the vocabulary size is only 10K).
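In case it helps, on PyTorch versions whose binary_cross_entropy_with_logits accepts a reduce argument, reduce=False returns the per-element losses so the calling code can do its own reduction; something along these lines (a sketch, with the summing/averaging step assumed):

per_element = F.binary_cross_entropy_with_logits(logits, nce_target, reduce=False)   # (batch, K+1)
loss = per_element.sum(1).mean()   # sum over the K+1 candidates, average over the batch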

Got it, thanks! We are probably not using the same version, since I did not compile PyTorch from source, so I have the following function signature:
torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=None, size_average=True)
Thanks anyway :wink:

I want to be cautious with what I say, but after I copied the function you were using (while still using the binaries, without compiling) it produced meaningful results!
So thanks a lot spartha! :smile: