High CPU time with LSTM

Hi, I’m training a network with an LSTM component in it. When I profiled the model, the LSTM part turned out to be the bottleneck. Strangely, it has a very high CPU time. I don’t know why this is happening. Any help would be appreciated.

Below is the sample code I used to profile the LSTM part in the model.

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.autograd as autograd

def repackage_hidden(h):
    """Detach a hidden state from its autograd history.

    The returned tensor(s) share storage with ``h`` but carry no gradient
    history. Backpropagation therefore stops here instead of flowing through
    every earlier step, which would keep the whole graph alive and grow memory
    on each call.

    Args:
        h: A tensor, a (possibly nested) tuple/list of tensors such as an
            LSTM's ``(h_n, c_n)`` pair, or ``None``.

    Returns:
        The same structure as ``h`` with every tensor detached. Sequence
        containers are returned as tuples. ``None`` is returned unchanged.
    """
    if h is None:
        # No hidden state yet (e.g. before the first forward pass).
        return None
    if isinstance(h, torch.Tensor):
        # Modern equivalent of the legacy Variable(h.data): same storage, no history.
        return h.detach()
    # The original lacked this `else` path (its return was unreachable), so
    # (h, c) tuples were never handled.
    return tuple(repackage_hidden(v) for v in h)

class GatePolicy(nn.Module):
    """Recurrent gating policy that runs one LSTM step per ``forward`` call.

    The recurrent state is stored on the module in ``self.hidden`` and carried
    over from one ``forward`` call to the next.

    Args:
        emb_dim: Number of input features per sample after ``forward``
            flattens ``x``.
        hidden_dim: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        rnn_type: Recurrent cell type. Only ``'lstm'`` is supported.

    Raises:
        ValueError: If ``rnn_type`` is not ``'lstm'``.
    """

    def __init__(self, emb_dim, hidden_dim, num_layers=2, rnn_type='lstm'):
        super(GatePolicy, self).__init__()
        self.rnn_type = rnn_type
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        if self.rnn_type == 'lstm':
            # No hard-coded .cuda(): calling .cuda()/.to() on the whole module
            # also moves this submodule, and CPU construction keeps working.
            self.rnn = nn.LSTM(emb_dim, hidden_dim, self.num_layers)
        else:
            # The original unconditionally overwrote self.rnn with None here,
            # which made forward() fail even for 'lstm'.
            raise ValueError('unsupported rnn_type: %r' % (rnn_type,))

        # (h, c) state carried across forward calls. None tells nn.LSTM to
        # start from zeros.
        self.hidden = None

    def init_hidden(self, batch_size):
        """Return a zero ``(h_0, c_0)`` pair for ``batch_size`` sequences.

        The axes semantics are (num_layers, minibatch_size, hidden_dim). Both
        tensors are created with the same device and dtype as the LSTM
        weights.
        """
        weight = next(self.rnn.parameters())
        shape = (self.num_layers, batch_size, self.hidden_dim)
        return (weight.new_zeros(shape), weight.new_zeros(shape))

    def repackage_hidden(self):
        """Detach ``self.hidden`` from the autograd graph of earlier steps."""
        # Resolves to the module-level repackage_hidden function, not this method.
        self.hidden = repackage_hidden(self.hidden)

    def forward(self, x):
        """Run one LSTM step on ``x`` and update ``self.hidden``.

        ``x`` is viewed as ``(1, batch_size, -1)``, i.e. a length-1 sequence
        whose per-sample features are all remaining elements of ``x``. That
        feature count must equal ``emb_dim``, and ``batch_size`` must match
        the batch size of any stored ``self.hidden``.

        Returns:
            The LSTM output with all size-1 dimensions squeezed out. This is
            normally ``(batch_size, hidden_dim)``; when ``batch_size == 1``
            the batch dimension is dropped as well.
        """
        batch_size = x.size(0)
        # NOTE(review): under multi-GPU DataParallel the replicas' LSTM weights
        # are not one contiguous chunk, so cuDNN may re-flatten them on every
        # call. Calling self.rnn.flatten_parameters() here is the usual remedy.
        x = x.view(1, batch_size, -1)
        # self.hidden is not detached here; call repackage_hidden() between
        # steps, or the autograd graph keeps growing across calls.
        # NOTE(review): with DataParallel over several GPUs this assignment
        # happens on a per-call replica and is not kept on the wrapped module.
        out, self.hidden = self.rnn(x, self.hidden)
        # squeeze() drops every size-1 dim (including batch when
        # batch_size == 1). Kept as-is because callers may rely on that shape.
        out = out.squeeze()
        return out

if __name__ == '__main__':
    model = GatePolicy(128, 128, num_layers=2).cuda()
    model = torch.nn.DataParallel(model)

    # Named `inp` so the builtin input() is not shadowed.
    inp = torch.randn(1, 128, 1, 1).cuda()
    input_var = Variable(inp)

    # Warm-up: the first cuDNN RNN call typically pays one-time host-side
    # setup (library/handle initialization, weight layout preparation).
    # Measured cold, that setup shows up as the CudnnRNN call's CPU time.
    for _ in range(3):
        model(input_var)
    torch.cuda.synchronize()
    # Start the profiled call from a fresh state, like the original single
    # call, and drop the autograd history accumulated during warm-up.
    model.module.hidden = None

    with torch.cuda.profiler.profile():
        with torch.autograd.profiler.emit_nvtx():
            output = model(input_var)
            # Kernel launches are asynchronous; wait so the range covers the GPU work.
            torch.cuda.synchronize()

Also, here is the profile. I was using a single V100 GPU.

------------  ---------------  ---------------  ---------------  ---------------  ---------------
Name                 CPU time        CUDA time            Calls        CPU total       CUDA total
------------  ---------------  ---------------  ---------------  ---------------  ---------------
Scatter             212.636us          0.000us                1        212.636us          0.000us
view                 12.404us          0.000us                1         12.404us          0.000us
CudnnRNN           2108.855us         17.856us                1       2108.855us         17.856us
squeeze               4.888us          0.000us                1          4.888us          0.000us