LSTM training speed

I use an LSTM implementation like this:

import torch
import torch.nn as nn

class StackedLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0):
        super(StackedLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList()

        # Stack LSTMCells; every layer after the first takes the previous
        # layer's hidden state as its input.
        for i in range(num_layers):
            self.layers.append(nn.LSTMCell(input_size, hidden_size))
            input_size = hidden_size

    def forward(self, input, hidden):
        h_0, c_0 = hidden
        h_1, c_1 = [], []
        for i, layer in enumerate(self.layers):
            h_1_i, c_1_i = layer(input, (h_0[i], c_0[i]))
            input = h_1_i
            # Apply dropout between layers, but not after the last layer.
            if i + 1 != self.num_layers:
                input = self.dropout(input)
            h_1 += [h_1_i]
            c_1 += [c_1_i]

        h_1 = torch.stack(h_1)
        c_1 = torch.stack(c_1)

        return input, (h_1, c_1)

When I compare it to nn.LSTM (the official API) on a language model, the perplexity is close but the training speed differs a lot: my code needs about 170s per epoch, while nn.LSTM takes only about 50s.
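For reference, this is roughly how I set up and call the built-in module for the comparison (a sketch; the sizes here are placeholders, not my exact hyperparameters):

# nn.LSTM consumes the whole sequence in one call, while the
# StackedLSTM above is stepped one timestep at a time.
official = nn.LSTM(input_size=512, hidden_size=512, num_layers=2, dropout=0.5)
seq = torch.randn(35, 20, 512)   # (seq_len, batch, input_size)
h0 = torch.zeros(2, 20, 512)     # (num_layers, batch, hidden_size)
c0 = torch.zeros(2, 20, 512)
output, (hn, cn) = official(seq, (h0, c0))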

Can someone tell me the reason?

nn.LSTM uses the optimized cuDNN LSTM kernel by default, which avoids the overhead of building a graph containing hundreds of separate mathematical operations, and the overhead of launching each of those operations as a separate kernel. You'll get the same slow performance from nn.LSTM by setting torch.backends.cudnn.enabled = False.
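You can confirm this with a minimal timing sketch (the sizes are arbitrary placeholders, and it assumes a CUDA GPU is available):

import time
import torch
import torch.nn as nn

def time_lstm(cudnn_enabled, steps=100):
    # Toggle the cuDNN backend; with it off, nn.LSTM falls back to the
    # op-by-op implementation, much like a hand-written LSTMCell loop.
    torch.backends.cudnn.enabled = cudnn_enabled
    lstm = nn.LSTM(512, 512, num_layers=2).cuda()
    x = torch.randn(35, 20, 512, device='cuda')
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(steps):
        out, _ = lstm(x)
        out.sum().backward()
        lstm.zero_grad()
    torch.cuda.synchronize()
    return time.time() - start

print('cuDNN on: ', time_lstm(True))
print('cuDNN off:', time_lstm(False))

On most GPUs the cuDNN run should come out several times faster, in line with the 50s vs. 170s gap you are seeing.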
