I added NCE (noise-contrastive estimation) loss to the word_language_model example in this fork. Convergence is slower (measured in epochs) than with CE, but that may just be a matter of tuning learning rates, gradient clipping, etc. My question, however, is about processing speed: the output-layer speedup should be roughly |V| / (K+1), where |V| is the vocabulary size and K is the number of noise samples per target word (minus some overhead). Instead, my NCE implementation is much slower than CE, roughly 3x the ms/batch in the logs below.
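To put a rough number on that expectation (assuming the example's default PTB data, where the vocabulary is about 10k words, an assumption on my part since the fork uses the stock data setup):

```python
V = 10000   # approximate PTB vocabulary size (assumption)
K = 100     # noise samples per target, matching --num_noise 100
print(V / (K + 1))  # ~99: each target needs K+1 dot products instead of V
```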
I am new to PyTorch, and I am hoping that more experienced users can point out any obvious fixes to speed up the computation. There are a number of index_select() and repeat() operations in linear_nce.py that are probably unnecessary, but I could not find suitable sparse-matrix or broadcasting alternatives in the documentation; a sketch of the kind of alternative I have in mind follows below. Below that are the first couple of epochs of the NCE and CE runs. Thanks for any suggestions.
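For illustration, here is a minimal sketch of what I mean by a broadcasting alternative. This is not the actual code in linear_nce.py; the names and shapes are illustrative. The idea is to gather only the sampled weight rows with advanced indexing and score them with a single batched matmul, rather than index_select() followed by repeat():

```python
import torch

# Illustrative sketch only, not the actual linear_nce.py code.
# hidden:  (B, H)    RNN outputs for a batch of B positions
# weight:  (V, H)    output-layer weight matrix over the vocabulary
# bias:    (V,)      output-layer bias
# samples: (B, K+1)  LongTensor: each row is [target, noise_1, ..., noise_K]
def nce_scores(hidden, weight, bias, samples):
    w = weight[samples]   # (B, K+1, H): gathers the needed rows, no repeat()
    b = bias[samples]     # (B, K+1)
    # Batched dot products: (B, K+1, H) @ (B, H, 1) -> (B, K+1, 1)
    scores = torch.bmm(w, hidden.unsqueeze(2)).squeeze(2) + b
    return scores         # (B, K+1) unnormalized logits

# Quick shape check on random data:
B, H, V, K = 32, 200, 10000, 100
scores = nce_scores(torch.randn(B, H), torch.randn(V, H),
                    torch.randn(V), torch.randint(V, (B, K + 1)))
print(scores.shape)  # torch.Size([32, 101])
```

Whether something like this is faster than my current index_select()/repeat() version, or whether the indexing itself is the bottleneck, is exactly what I am unsure about.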
NCE
python main.py --cuda --epochs 4 --loss nce --num_noise 100
| epoch 1 | 200/ 1327 batches | lr 20.00 | ms/batch 39.14 | nce_loss 5.79
| epoch 1 | 400/ 1327 batches | lr 20.00 | ms/batch 37.79 | nce_loss 5.40
| epoch 1 | 600/ 1327 batches | lr 20.00 | ms/batch 37.75 | nce_loss 5.17
| epoch 1 | 800/ 1327 batches | lr 20.00 | ms/batch 37.73 | nce_loss 5.06
| epoch 1 | 1000/ 1327 batches | lr 20.00 | ms/batch 36.82 | nce_loss 4.91
| epoch 1 | 1200/ 1327 batches | lr 20.00 | ms/batch 37.81 | nce_loss 4.81
| end of epoch 1 | time: 51.17s | valid loss 6.02 | valid ppl 410.86
| epoch 2 | 200/ 1327 batches | lr 20.00 | ms/batch 38.00 | nce_loss 4.69
| epoch 2 | 400/ 1327 batches | lr 20.00 | ms/batch 37.79 | nce_loss 4.61
| epoch 2 | 600/ 1327 batches | lr 20.00 | ms/batch 37.80 | nce_loss 4.55
| epoch 2 | 800/ 1327 batches | lr 20.00 | ms/batch 37.85 | nce_loss 4.50
| epoch 2 | 1000/ 1327 batches | lr 20.00 | ms/batch 37.77 | nce_loss 4.49
| epoch 2 | 1200/ 1327 batches | lr 20.00 | ms/batch 37.78 | nce_loss 4.40
| end of epoch 2 | time: 50.99s | valid loss 5.70 | valid ppl 299.02
CE
python main.py --cuda --epochs 2 --loss ce
| epoch 1 | 200/ 1327 batches | lr 20.00 | ms/batch 14.67 | loss 6.94 | ppl 1031.24
| epoch 1 | 400/ 1327 batches | lr 20.00 | ms/batch 12.48 | loss 6.31 | ppl 551.36
| epoch 1 | 600/ 1327 batches | lr 20.00 | ms/batch 12.48 | loss 6.04 | ppl 420.70
| epoch 1 | 800/ 1327 batches | lr 20.00 | ms/batch 12.50 | loss 5.78 | ppl 323.96
| epoch 1 | 1000/ 1327 batches | lr 20.00 | ms/batch 12.41 | loss 5.64 | ppl 281.56
| epoch 1 | 1200/ 1327 batches | lr 20.00 | ms/batch 12.38 | loss 5.49 | ppl 241.55
| end of epoch 1 | time: 17.80s | valid loss 5.40 | valid ppl 220.75
| epoch 2 | 200/ 1327 batches | lr 20.00 | ms/batch 12.45 | loss 5.40 | ppl 220.57
| epoch 2 | 400/ 1327 batches | lr 20.00 | ms/batch 12.44 | loss 5.33 | ppl 206.70
| epoch 2 | 600/ 1327 batches | lr 20.00 | ms/batch 12.43 | loss 5.28 | ppl 197.03
| epoch 2 | 800/ 1327 batches | lr 20.00 | ms/batch 12.41 | loss 5.17 | ppl 176.66
| epoch 2 | 1000/ 1327 batches | lr 20.00 | ms/batch 12.40 | loss 5.17 | ppl 175.13
| epoch 2 | 1200/ 1327 batches | lr 20.00 | ms/batch 12.45 | loss 5.06 | ppl 157.51
| end of epoch 2 | time: 17.34s | valid loss 5.10 | valid ppl 164.46