Hello,
It seems faster to apply dropout outside of the stacked RNN module, i.e. to run three single-layer RNNs with explicit dropout in between rather than one 3-layer RNN with its built-in dropout.
Note that this does not happen in the unidirectional case.
Can you explain what causes this difference?
import time

import torch


def std_fw(rnn, src):
    # One call to the stacked 3-layer RNN with built-in inter-layer dropout.
    return rnn(src)


def split_fw(rnn1, rnn2, rnn3, dropout, src):
    # The same stack split into three 1-layer RNNs, with dropout applied
    # manually to the packed data between layers.
    output, _ = rnn1(src)
    output = torch.nn.utils.rnn.PackedSequence(
        torch.nn.functional.dropout(output.data, dropout, True),
        batch_sizes=output.batch_sizes,
    )
    output, _ = rnn2(output)
    output = torch.nn.utils.rnn.PackedSequence(
        torch.nn.functional.dropout(output.data, dropout, True),
        batch_sizes=output.batch_sizes,
    )
    return rnn3(output)


def test():
    data = torch.randn(1000, 64, 14)  # (seq_len, batch, features)
    lengths = torch.randint(1, 1001, (64,)).sort(descending=True)[0]
    device = torch.device("cuda:0")
    src = torch.nn.utils.rnn.pack_padded_sequence(data, lengths).to(device)

    std_rnn = torch.nn.RNN(14, 53, 3, dropout=0.5, bidirectional=True).to(device)
    start = time.time()
    x = std_fw(std_rnn, src)
    print(f"std_fw: {time.time() - start:.6}s")

    split_rnn1 = torch.nn.RNN(14, 53, 1, dropout=0, bidirectional=True).to(device)
    split_rnn2 = torch.nn.RNN(2 * 53, 53, 1, dropout=0, bidirectional=True).to(device)
    split_rnn3 = torch.nn.RNN(2 * 53, 53, 1, dropout=0, bidirectional=True).to(device)
    start = time.time()
    x = split_fw(split_rnn1, split_rnn2, split_rnn3, 0.5, src)
    print(f"split_fw: {time.time() - start:.6}s")
Back to the timing, this quick test shows (the gap is even more obvious with timeit):
std_fw: 0.094436s
split_fw: 0.055170s
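One caveat on these numbers: CUDA kernels run asynchronously and the first call also pays one-time cuDNN setup, so a bare time.time() pair can mislead. Here is a sketch of the kind of timeit loop I mean (bench is my own helper, again assuming the modules from test() are in scope), with warmup and torch.cuda.synchronize():

import timeit

def bench(fn, warmup=3, number=20):
    # Average seconds per call after warmup, counting kernel completion.
    for _ in range(warmup):  # warm up cuDNN and the CUDA allocator
        fn()
    torch.cuda.synchronize()
    total = timeit.timeit(lambda: (fn(), torch.cuda.synchronize()), number=number)
    return total / number

print(f"std_fw:   {bench(lambda: std_fw(std_rnn, src)):.6}s")
print(f"split_fw: {bench(lambda: split_fw(split_rnn1, split_rnn2, split_rnn3, 0.5, src)):.6}s")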