Hi, I’m training a network that has an LSTM component in it. When I profiled the model, the LSTM part turned out to be the bottleneck. Strangely, it shows a very high CPU time compared with its CUDA time, and I don’t know why this is happening. Any help would be appreciated.
Below is the sample code I used to profile the LSTM part in the model.
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.autograd as autograd
def repackage_hidden(h):
    """Detach hidden state(s) from the autograd graph that produced them.

    Cutting the history stops backpropagation from reaching into earlier
    forward passes, which keeps memory usage bounded when a recurrent state
    is carried from one call to the next.

    Args:
        h: ``None``, a single tensor, or an iterable (possibly nested) of
            tensors, such as the ``(h, c)`` pair produced by an LSTM.

    Returns:
        ``None`` when ``h`` is ``None``; a detached tensor when ``h`` is a
        tensor; otherwise a tuple holding the recursively processed elements.
    """
    if h is None:
        # Nothing has been stored yet, so there is nothing to detach.
        return None
    # Use isinstance rather than an exact type comparison: an exact check
    # against ``Variable`` does not match plain tensors, which would then be
    # iterated element by element instead of being detached.
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)
class GatePolicy(nn.Module):
    """Recurrent policy that feeds one time step per call through an LSTM.

    The recurrent state returned by the LSTM is stored on the module in
    ``self.hidden`` and passed back in on the next ``forward`` call, so
    consecutive calls must use the same batch size. Call
    ``repackage_hidden`` between steps to cut the autograd history of the
    stored state.

    Args:
        emb_dim: Number of input features per sample.
        hidden_dim: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        rnn_type: Only ``'lstm'`` builds a recurrent layer; any other value
            leaves ``self.rnn`` as ``None``.
    """

    def __init__(self, emb_dim, hidden_dim, num_layers=2, rnn_type='lstm'):
        super(GatePolicy, self).__init__()
        self.rnn_type = rnn_type
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        if self.rnn_type == 'lstm':
            self.rnn = nn.LSTM(emb_dim, hidden_dim, self.num_layers)
            # Keep the historical GPU placement when a GPU exists, but do
            # not fail at construction time on a CPU-only host.
            if torch.cuda.is_available():
                self.rnn.cuda()
        else:
            self.rnn = None
        # Recurrent state carried between forward calls; None until the
        # first call, which lets the LSTM start from zeros.
        self.hidden = None

    def _state_device(self):
        """Return the device on which new state tensors should be created."""
        for param in self.parameters():
            return param.device
        # No parameters (no recurrent layer was built): fall back to the
        # GPU when one is available, otherwise the CPU.
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def init_hidden(self, batch_size):
        """Build a zero ``(h, c)`` state pair for ``batch_size`` samples.

        Returns:
            Two zero tensors of shape
            ``(num_layers, batch_size, hidden_dim)`` located on the same
            device as the module's parameters.
        """
        # The axes semantics are (num_layers, minibatch_size, hidden_dim)
        device = self._state_device()
        shape = (self.num_layers, batch_size, self.hidden_dim)
        return (torch.zeros(*shape, device=device),
                torch.zeros(*shape, device=device))

    def repackage_hidden(self):
        """Detach the stored state from the graph of previous steps."""
        # Before the first forward pass there is no state to detach.
        if self.hidden is not None:
            self.hidden = repackage_hidden(self.hidden)

    def forward(self, x):
        """Run a single time step and return the top-layer LSTM output.

        Args:
            x: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened into the feature axis.

        Returns:
            Tensor of shape ``(batch_size, hidden_dim)``.
        """
        # Take the convolution output of each step
        batch_size = x.size(0)
        # NOTE: a call to self.rnn.flatten_parameters() was tried here and
        # left disabled.
        x = x.view(1, batch_size, -1)
        out, self.hidden = self.rnn(x, self.hidden)
        # Drop only the length-1 sequence axis. A bare squeeze() would also
        # remove the batch axis when batch_size == 1, making the output
        # rank depend on the input.
        return out.squeeze(0)
if __name__ == '__main__':
    # Build the policy on the GPU and wrap it for data-parallel execution.
    policy = torch.nn.DataParallel(GatePolicy(128, 128, num_layers=2).cuda())

    # One random sample; forward flattens everything after the batch axis,
    # giving 128 input features.
    sample = Variable(torch.randn(1, 128, 1, 1).cuda())

    # NOTE(review): the nesting below follows the usual emit_nvtx usage
    # pattern (first pass inside the CUDA profiler range, second pass
    # additionally NVTX-annotated) - confirm against the intended script.
    with torch.cuda.profiler.profile():
        policy(sample)
        with torch.autograd.profiler.emit_nvtx():
            policy(sample)
Here is the profiler output as well. I was using a single V100 GPU.
------------ --------------- --------------- --------------- --------------- ---------------
Name CPU time CUDA time Calls CPU total CUDA total
------------ --------------- --------------- --------------- --------------- ---------------
Scatter 212.636us 0.000us 1 212.636us 0.000us
view 12.404us 0.000us 1 12.404us 0.000us
CudnnRNN 2108.855us 17.856us 1 2108.855us 17.856us
squeeze 4.888us 0.000us 1 4.888us 0.000us