Speeding up a layer norm LSTM


(Wei Cao) #1

Greetings!

I implemented a layer-normalized LSTMCell from scratch. Everything works, but it is much slower than the original LSTM. I noticed that the original LSTMCell is based on LSTMFused_updateOutput, which is implemented in C. I am wondering whether there is an easy way to speed up the LayerNorm LSTM without modifying the C implementation in the backend. Thank you very much!

Here is my code:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class LayerNorm(nn.Module):
    def __init__(self, nb_features, eps=1e-5):
        super(LayerNorm, self).__init__()
        self.eps = eps
        self.gain = nn.Parameter(torch.ones(1, nb_features))
        self.bias = nn.Parameter(torch.zeros(1, nb_features))

    def forward(self, input):
        # keepdim=True keeps the statistics shaped (batch, 1) so they broadcast over features
        mean = input.mean(1, keepdim=True).expand_as(input)
        std = input.std(1, keepdim=True).expand_as(input)
        x = (input - mean) / (std + self.eps)
        return x * self.gain.expand_as(x) + self.bias.expand_as(x)


class LayerNormLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(LayerNormLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.weight_ih = Parameter(torch.empty(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.empty(4 * hidden_size, hidden_size))

        self.bias_ih = Parameter(torch.empty(4 * hidden_size))
        self.bias_hh = Parameter(torch.empty(4 * hidden_size))

        self.ln_ih = LayerNorm(4 * hidden_size)
        self.ln_hh = LayerNorm(4 * hidden_size)
        self.ln_ho = LayerNorm(hidden_size)

        self.reset_parameters()

    def reset_parameters(self):
        # torch.empty leaves memory uninitialized; use the same uniform init as nn.LSTMCell
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for w in (self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh):
            nn.init.uniform_(w, -stdv, stdv)

    def forward(self, input, hidden):
        hx, cx = hidden
        # Normalize the input and recurrent projections separately, then sum them
        gates = (self.ln_ih(F.linear(input, self.weight_ih, self.bias_ih))
                 + self.ln_hh(F.linear(hx, self.weight_hh, self.bias_hh)))

        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(self.ln_ho(cy))

        return hy, cy
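One restructuring that needs no C code: the input branch `ln_ih(F.linear(input, self.weight_ih, self.bias_ih))` does not depend on `hx`, so for a whole sequence it can be computed with one large matmul and one normalization before the recurrent loop, leaving only the hidden-to-hidden work inside the loop. A minimal sketch, assuming inputs shaped (seq_len, batch, input_size); the function names are hypothetical, and the layer norms use unit gain and zero bias for brevity (a per-feature gain and bias could be applied to the precomputed tensor in the same way):

```python
import torch
import torch.nn.functional as F


def layer_norm(x, eps=1e-5):
    # Same normalization as the LayerNorm module above (unit gain, zero bias),
    # taken over the last dimension so it also works on (seq_len, batch, features).
    mean = x.mean(-1, keepdim=True)
    std = x.std(-1, keepdim=True)
    return (x - mean) / (std + eps)


def ln_lstm_step(gates_x, h, c, w_hh, b_hh):
    # gates_x is the already-normalized input projection for this time step.
    gates = gates_x + layer_norm(F.linear(h, w_hh, b_hh))
    i, f, g, o = gates.chunk(4, 1)
    c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
    h = torch.sigmoid(o) * torch.tanh(layer_norm(c))
    return h, c


def ln_lstm_sequence(inputs, h, c, w_ih, w_hh, b_ih, b_hh):
    # Hypothetical helper: the input branch is computed for every time step at once,
    # outside the loop, because it does not depend on the recurrent state.
    gates_x = layer_norm(F.linear(inputs, w_ih, b_ih))  # (seq_len, batch, 4 * hidden)
    outputs = []
    for t in range(inputs.size(0)):
        h, c = ln_lstm_step(gates_x[t], h, c, w_hh, b_hh)
        outputs.append(h)
    return torch.stack(outputs), (h, c)
```

How much this helps depends on sequence length and batch size: it removes `seq_len` small matmuls from the loop, but the recurrent part stays sequential.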

(Wei Cao) #2

Can anyone help? The training speed is terrible :sob:


#3

You could pass your gates to the fused pointwise backend and then recompute hy. That would give some gains; see https://github.com/pytorch/pytorch/blob/ceb4f84d12304d03a6a46693e54390869c0c208e/torch/nn/_functions/rnn.py#L23-L28

To get really strong performance, you would want to merge the mean and std computations into a single kernel, and also fuse

x = (input - mean) / (std + self.eps)
x * self.gain.expand_as(x) + self.bias.expand_as(x)

into a single kernel. That is a sizable effort.
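Short of writing a custom CUDA kernel, the same math can at least be expressed with fewer ops. A minimal sketch (function names are hypothetical): `torch.var_mean` returns both statistics from a single call, and `torch.addcmul` computes `bias + normed * gain` as one op instead of a multiply followed by an add. This only reduces the number of separate ops; it is not the fully fused kernel described above, and any speedup is version- and device-dependent.

```python
import torch


def layer_norm_reference(x, gain, bias, eps=1e-5):
    # The formula from the first post: separate mean and std reductions,
    # then separate elementwise ops for the normalization and the affine step.
    mean = x.mean(1, keepdim=True)
    std = x.std(1, keepdim=True)
    return (x - mean) / (std + eps) * gain + bias


def layer_norm_fewer_ops(x, gain, bias, eps=1e-5):
    # torch.var_mean returns (variance, mean) from one call; its default
    # unbiased variance matches x.std() above once square-rooted.
    var, mean = torch.var_mean(x, dim=1, keepdim=True)
    normed = (x - mean) / (var.sqrt() + eps)
    # addcmul computes bias + normed * gain as a single op (with broadcasting).
    return torch.addcmul(bias, normed, gain)
```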


(Kris Cao) #4

Hi

Sorry to revive an old thread, but I've recently reimplemented a LayerNorm LSTM (using the code above). Even with the suggestion to use the fused backend, I'm getting pretty horrible speeds: about half the speed of the native LSTMCell implementation.

I suspect there aren't any other ways to get a speedup, and that most of the difference comes from the native implementation being able to call the cuDNN-optimized LSTM implementation directly. Is there a way to get layer norm into the cuDNN LSTM implementation?


#5

@kroscoo, were you able to find a way to address this problem? I am still looking for a solution.


(Simon Wang) #6

Using the official LayerNorm module instead of a custom one like the one above should speed it up a lot.


#7

@SimonW Sorry, I was not aware that PyTorch had an LSTM/GRU with layer norm built in, and I could not find it. Can you please point me to it? Thanks a lot.


#8

https://pytorch.org/docs/stable/nn.html#torch.nn.LayerNorm
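`nn.LayerNorm` is a layer norm module, not an LSTM, but it can replace the custom `LayerNorm` class inside `LayerNormLSTMCell` (e.g. `self.ln_ih = nn.LayerNorm(4 * hidden_size)`). A small sketch with illustrative sizes, showing what it computes. Note that it is not numerically identical to the custom module: it uses the biased variance and adds `eps` inside the square root, whereas the custom module divides by `std + eps` with the unbiased std.

```python
import torch
import torch.nn as nn

hidden_size = 20
batch = 3

# nn.LayerNorm normalizes over the trailing dimension given in normalized_shape
# and has a learnable elementwise weight (gain) and bias, like the custom module.
ln_ih = nn.LayerNorm(4 * hidden_size, eps=1e-5)

gates = torch.randn(batch, 4 * hidden_size)
out = ln_ih(gates)

# Equivalent manual computation: biased variance, eps added inside the square root.
mean = gates.mean(-1, keepdim=True)
var = ((gates - mean) ** 2).mean(-1, keepdim=True)
manual = (gates - mean) / torch.sqrt(var + 1e-5) * ln_ih.weight + ln_ih.bias
```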


(ravvv) #9

Hi,

Has anyone found a way to use a LayerNorm LSTM with cuDNN? I am also tackling this problem.
Thanks a lot.


(Kris Cao) #10

As of 1.0, the fused pointwise backend is no longer importable. This is causing some pretty bad regressions in my model's performance; is it possible to fix this? Even after rewriting the LayerNormLSTM in TorchScript, it is a bit slower than it was before.
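For reference, a TorchScript version of the layer-norm LSTM can be written as a scripted cell plus a layer that runs the time loop inside the script. A minimal sketch, assuming `nn.Linear` projections and the built-in `nn.LayerNorm` (so the initialization and normalization differ slightly from the first post); the class names are illustrative, and no performance claim is made:

```python
from typing import List, Tuple

import torch
import torch.nn as nn
from torch import Tensor


class ScriptLayerNormLSTMCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.ih = nn.Linear(input_size, 4 * hidden_size)
        self.hh = nn.Linear(hidden_size, 4 * hidden_size)
        self.ln_ih = nn.LayerNorm(4 * hidden_size)
        self.ln_hh = nn.LayerNorm(4 * hidden_size)
        self.ln_ho = nn.LayerNorm(hidden_size)

    def forward(self, x: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        hx, cx = state
        gates = self.ln_ih(self.ih(x)) + self.ln_hh(self.hh(hx))
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        cy = torch.sigmoid(forgetgate) * cx + torch.sigmoid(ingate) * torch.tanh(cellgate)
        hy = torch.sigmoid(outgate) * torch.tanh(self.ln_ho(cy))
        return hy, cy


class ScriptLayerNormLSTMLayer(nn.Module):
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.cell = ScriptLayerNormLSTMCell(input_size, hidden_size)

    def forward(
        self, inputs: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        # inputs: (seq_len, batch, input_size); the loop is compiled with the module
        outputs: List[Tensor] = []
        for x in inputs.unbind(0):
            state = self.cell(x, state)
            outputs.append(state[0])
        return torch.stack(outputs), state


# Usage: script the whole layer so the per-step Python overhead is removed.
layer = torch.jit.script(ScriptLayerNormLSTMLayer(10, 20))
```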


#11

@kriscoo I am investigating how to improve RNN performance in TorchScript. Do you have a code snippet showing how you're writing the LayerNormLSTM?