Dropout faster without stacked RNN

Hello,

It seems faster to put the dropout outside of the stacked RNN module.
Note that this does not seem to be the case without bidirectional=True.
Can you explain what makes the difference?

import time

import torch

def std_fw(rnn, src):
    # Forward through a single stacked (multi-layer) RNN with built-in dropout.
    return rnn(src)

def split_fw(rnn1, rnn2, rnn3, dropout, src):
    # Forward through three single-layer RNNs, applying dropout manually to the
    # packed data between layers (mirroring nn.RNN's built-in inter-layer dropout).
    output, _ = rnn1(src)
    output = torch.nn.utils.rnn.PackedSequence(
        torch.nn.functional.dropout(output.data, dropout, True), batch_sizes=output.batch_sizes
    )
    output, _ = rnn2(output)
    output = torch.nn.utils.rnn.PackedSequence(
        torch.nn.functional.dropout(output.data, dropout, True), batch_sizes=output.batch_sizes
    )
    return rnn3(output)
    
def test():
    data = torch.randn(1000, 64, 14)
    lengths = torch.randint(1, 1001, (64,)).sort(descending=True)[0]
    device = torch.device("cuda:0")

    src = torch.nn.utils.rnn.pack_padded_sequence(data, lengths).to(device)

    # Single stacked 3-layer bidirectional RNN with built-in inter-layer dropout.
    std_rnn = torch.nn.RNN(14, 53, 3, dropout=0.5, bidirectional=True).to(device)
    start = time.time()
    x = std_fw(std_rnn, src)
    print(f"std_fw: {time.time() - start:.6}s")

    # Same architecture split into three single-layer RNNs with manual dropout in between.
    split_rnn1 = torch.nn.RNN(14, 53, 1, dropout=0, bidirectional=True).to(device)
    split_rnn2 = torch.nn.RNN(2 * 53, 53, 1, dropout=0, bidirectional=True).to(device)
    split_rnn3 = torch.nn.RNN(2 * 53, 53, 1, dropout=0, bidirectional=True).to(device)
    start = time.time()
    x = split_fw(split_rnn1, split_rnn2, split_rnn3, 0.5, src)
    print(f"split_fw: {time.time() - start:.6}s")

This quick test gives the following (the gap is even more obvious with timeit):

std_fw: 0.094436s
split_fw: 0.055170s

My guess is that this is not due to dropout, but due to how the RNN layers are stacked.
Can you set dropout=0 in both cases and average the time over, say, 20 runs?
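
Something along these lines would do (a rough sketch; the torch.cuda.synchronize() calls are there so asynchronous CUDA launches don't skew the wall-clock numbers):

import time
import torch

def average_time(run_fn, n_runs=20, warmup=2):
    # Average wall-clock time of n_runs forward passes, after a short warm-up.
    for _ in range(warmup):
        run_fn()
    torch.cuda.synchronize()  # wait for any pending GPU work before starting the clock
    start = time.time()
    for _ in range(n_runs):
        run_fn()
    torch.cuda.synchronize()  # wait for all launched kernels to finish
    return (time.time() - start) / n_runs

# e.g. average_time(lambda: std_fw(std_rnn, src))
#      average_time(lambda: split_fw(split_rnn1, split_rnn2, split_rnn3, 0.0, src))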

Here are the results (average on 20 runs):

dropout = 0

std_fw: 0.0788724s
split_fw: 0.0804384s

dropout = 0.5

std_fw: 0.105565s
split_fw: 0.0755856s

Are you getting exactly the same output in both cases? There might be a slight difference in the implementation.
Looks like in std_fw, you are adding dropout on the RNN, whereas in split_fw you are adding dropout on the output cells.

Of course, the outputs are not the same because I didn’t control the random state here.
But the two models are the same in theory.

According to the nn.RNN documentation:
dropout – If non-zero, introduces a Dropout layer on the outputs of each RNN layer except the last layer, with dropout probability equal to dropout. Default: 0
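
Just to make the "same in theory" concrete, here is a minimal sketch (not the code I timed above) of how the stacked weights could be copied into the three single-layer RNNs so that both paths compute exactly the same function once dropout is disabled; the weight_ih_l{k} / weight_hh_l{k} names follow the nn.RNN parameter convention:

import torch

def copy_stacked_weights(std_rnn, split_rnns):
    # Copy layer k of the stacked RNN into layer 0 of the k-th single-layer RNN.
    with torch.no_grad():
        for layer, rnn in enumerate(split_rnns):
            for suffix in ("", "_reverse"):  # "_reverse" parameters exist because bidirectional=True
                for name in ("weight_ih", "weight_hh", "bias_ih", "bias_hh"):
                    getattr(rnn, f"{name}_l0{suffix}").copy_(
                        getattr(std_rnn, f"{name}_l{layer}{suffix}")
                    )

# With dropout disabled (eval mode / dropout=0.0), the outputs should then match:
# copy_stacked_weights(std_rnn, [split_rnn1, split_rnn2, split_rnn3])
# for m in (std_rnn, split_rnn1, split_rnn2, split_rnn3):
#     m.eval()
# out_std, _ = std_fw(std_rnn, src)
# out_split, _ = split_fw(split_rnn1, split_rnn2, split_rnn3, 0.0, src)
# print(torch.allclose(out_std.data, out_split.data, atol=1e-5))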

Can you set the random state and try measuring the performance again?
Ideally, if the models are the same, there shouldn’t be such a drastic difference in performance.

Hi, I played a little bit with torch.autograd.profiler.
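
This is roughly how the tables below were produced (a minimal sketch; use_cuda and profile_memory match the columns shown, though the exact arguments I used may have differed slightly):

import torch

def profile_forward(run_fn, n_iters=20):
    # Legacy autograd profiler; use_cuda/profile_memory produce the CUDA time
    # and memory columns shown in the tables below.
    with torch.autograd.profiler.profile(use_cuda=True, profile_memory=True) as prof:
        for _ in range(n_iters):
            run_fn()
    print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=15))

# e.g. profile_forward(lambda: std_fw(std_rnn, src))
#      profile_forward(lambda: split_fw(split_rnn1, split_rnn2, split_rnn3, 0.5, src))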

It seems that dropout alone is really fast, and according to the code (as far as I can understand it), the stacked standard module is doing the same thing as what I'm doing here. So I don't understand why it's so slow.
Could it be an issue with at::dropout, where the dispatch doesn't work as expected?

And I want to correct what I said before: it is also true with a unidirectional RNN and without a PackedSequence. This strange behavior also happens with LSTM and GRU.
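
For reference, this is the kind of plain-tensor, unidirectional variant I mean (a sketch of the setup, not the exact code behind the tables below):

import torch

def split_fw_plain(rnn1, rnn2, rnn3, dropout, src):
    # Same idea without PackedSequence: dropout is applied directly to the
    # (seq_len, batch, hidden) output tensor between the single-layer RNNs.
    output, _ = rnn1(src)
    output = torch.nn.functional.dropout(output, dropout, True)
    output, _ = rnn2(output)
    output = torch.nn.functional.dropout(output, dropout, True)
    return rnn3(output)

# e.g. unidirectional setup on plain tensors:
# device = torch.device("cuda:0")
# src = torch.randn(1000, 64, 14, device=device)
# std_rnn = torch.nn.RNN(14, 53, 3, dropout=0.5).to(device)
# rnn1 = torch.nn.RNN(14, 53, 1).to(device)
# rnn2 = torch.nn.RNN(53, 53, 1).to(device)
# rnn3 = torch.nn.RNN(53, 53, 1).to(device)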

Here are the results (for 20 iterations):

std_fw: 0.115978s
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
             aten::_cudnn_rnn        99.44%        2.305s        99.56%        2.308s     115.382ms        2.336s        99.94%        2.336s     116.806ms           0 b           0 b       1.69 Gb      -1.11 Gb            20  
                  aten::empty         0.15%       3.556ms         0.15%       3.556ms      25.397us       0.000us         0.00%       0.000us       0.000us           0 b           0 b       2.80 Gb       2.80 Gb           140  
                 aten::select         0.09%       2.106ms         0.12%       2.780ms      69.504us     284.239us         0.01%     284.239us       7.106us           0 b           0 b           0 b           0 b            40  
    aten::_local_scalar_dense         0.07%       1.561ms         0.07%       1.561ms      39.035us     211.103us         0.01%     211.103us       5.278us           0 b           0 b           0 b           0 b            40  
                   aten::set_         0.05%       1.175ms         0.05%       1.175ms      58.767us     161.768us         0.01%     161.768us       8.088us           0 b           0 b           0 b           0 b            20  
                   aten::item         0.04%     858.459us         0.10%       2.420ms      60.496us     129.619us         0.01%     340.722us       8.518us           0 b           0 b           0 b           0 b            40  
               aten::rnn_tanh         0.04%     841.586us        99.70%        2.311s     115.536ms     173.881us         0.01%        2.336s     116.824ms           0 b           0 b       1.69 Gb           0 b            20  
             aten::as_strided         0.03%     673.823us         0.03%     673.823us      16.846us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            40  
                  aten::fill_         0.02%     542.642us         0.02%     542.642us      27.132us     153.827us         0.01%     153.827us       7.691us           0 b           0 b           0 b           0 b            20  
                  aten::zeros         0.02%     454.993us         0.08%       1.829ms      91.436us     117.128us         0.01%     379.144us      18.957us           0 b           0 b       1.55 Mb           0 b            20  
                  aten::zero_         0.02%     421.823us         0.04%     964.465us      48.223us     108.189us         0.00%     262.016us      13.101us           0 b           0 b           0 b           0 b            20  
             aten::contiguous         0.02%     383.369us         0.02%     383.369us      19.168us      46.632us         0.00%      46.632us       2.332us           0 b           0 b           0 b           0 b            20  
    aten::cudnn_is_acceptable         0.02%     366.571us         0.02%     366.571us      18.329us      40.576us         0.00%      40.576us       2.029us           0 b           0 b           0 b           0 b            20  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.318s
CUDA time total: 2.337s

split_fw: 0.0534257s
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
             aten::_cudnn_rnn        95.47%        1.015s        96.14%        1.022s      17.038ms        1.021s        95.98%        1.022s      17.037ms           0 b           0 b       1.65 Gb      -3.59 Gb            60  
                  aten::empty         1.01%      10.758ms         1.01%      10.758ms      21.516us       0.000us         0.00%       0.000us       0.000us           0 b           0 b       5.93 Gb       5.93 Gb           500  
                 aten::select         0.52%       5.570ms         0.71%       7.567ms      63.057us       6.854ms         0.64%       6.854ms      57.118us           0 b           0 b           0 b           0 b           120  
    aten::_local_scalar_dense         0.46%       4.872ms         0.46%       4.872ms      40.601us       4.817ms         0.45%       4.817ms      40.141us           0 b           0 b           0 b           0 b           120  
                 aten::stride         0.43%       4.522ms         0.43%       4.522ms      14.131us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           320  
                   aten::set_         0.32%       3.446ms         0.32%       3.446ms      57.429us       3.411ms         0.32%       3.411ms      56.857us           0 b           0 b           0 b           0 b            60  
         aten::_fused_dropout         0.28%       2.979ms         0.91%       9.707ms     242.675us      11.934ms         1.12%      11.934ms     298.349us           0 b           0 b     703.71 Mb           0 b            40  
                   aten::item         0.24%       2.543ms         0.70%       7.415ms      61.789us       2.572ms         0.24%       7.389ms      61.575us           0 b           0 b           0 b           0 b           120  
               aten::rnn_tanh         0.22%       2.306ms        96.97%        1.031s      17.185ms       4.279ms         0.40%        1.031s      17.184ms           0 b           0 b       1.65 Gb           0 b            60  
             aten::as_strided         0.19%       1.996ms         0.19%       1.996ms      16.637us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           120  
                  aten::fill_         0.14%       1.507ms         0.14%       1.507ms      25.113us       1.540ms         0.14%       1.540ms      25.665us           0 b           0 b           0 b           0 b            60  
                aten::dropout         0.14%       1.490ms         1.05%      11.197ms     279.935us       1.428ms         0.13%      13.362ms     334.038us           0 b           0 b     703.71 Mb           0 b            40  
                  aten::zeros         0.12%       1.321ms         0.49%       5.262ms      87.696us       2.482ms         0.23%       5.233ms      87.212us           0 b           0 b       1.55 Mb           0 b            60  
                  aten::zero_         0.12%       1.253ms         0.26%       2.760ms      45.998us       1.211ms         0.11%       2.751ms      45.843us           0 b           0 b           0 b           0 b            60  
    aten::cudnn_is_acceptable         0.10%       1.103ms         0.10%       1.103ms      18.375us       1.101ms         0.10%       1.101ms      18.356us           0 b           0 b           0 b           0 b            60  
             aten::contiguous         0.10%       1.091ms         0.10%       1.091ms      18.175us       1.086ms         0.10%       1.086ms      18.103us           0 b           0 b           0 b           0 b            60  
                     aten::to         0.07%     777.407us         0.07%     777.407us      19.435us      68.049us         0.01%      68.049us       1.701us           0 b           0 b           0 b           0 b            40  
             aten::empty_like         0.06%     676.402us         0.14%       1.451ms      36.284us       0.000us         0.00%       0.000us       0.000us           0 b           0 b     562.97 Mb           0 b            40  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.063s
CUDA time total: 1.064s

I have reported the issue here: Unexpected slow dropout in stacked RNN/LSTM/GRU · Issue #50879 · pytorch/pytorch · GitHub