Hi,
I'm new to PyTorch. My forward and backward passes are very slow. Are there any obvious optimizations I could make to this code that I'm not aware of?
import torch
import torch.nn as nn
import torch.nn.functional as F


class SpatialAttDRQNBody(nn.Module):
    def __init__(self, in_channels=4):
        super(SpatialAttDRQNBody, self).__init__()
        self.feature_dim = 256
        self.rnn_input_dim = 256
        self.batch_size = 1
        self.unroll = 4
        in_channels = 1  # each time-step chunk is a single frame
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4, bias=False)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, bias=False)
        self.conv3 = nn.Conv2d(64, 256, kernel_size=3, stride=1, bias=False)
        self.att1 = nn.Linear(self.feature_dim, self.feature_dim, bias=False)
        self.att2 = nn.Linear(self.feature_dim, self.feature_dim, bias=False)
        self.lstm = nn.LSTM(self.rnn_input_dim, self.feature_dim, num_layers=1)
        self.hidden = self.init_hidden()

    def init_hidden(self, num_layers=1, batch=1):
        # initialize the hidden and cell states
        return (torch.zeros(num_layers, batch, self.feature_dim).cuda(),
                torch.zeros(num_layers, batch, self.feature_dim).cuda())

    def repackage_hidden(self, h):
        # detach the hidden state from the graph of the previous step
        if isinstance(h, torch.Tensor):
            return h.detach()
        else:
            return tuple(self.repackage_hidden(state) for state in h)

    def forward(self, x):
        batch = x.size(0)
        output = torch.Tensor()
        xchunks = torch.chunk(x, self.unroll, 1)  # split stacked frames into time steps
        self.hidden = self.init_hidden(batch=batch)
        for ts in range(len(xchunks)):
            y = F.relu(self.conv1(xchunks[ts]))
            y = F.relu(self.conv2(y))
            y = F.relu(self.conv3(y))
            y = y.view(batch, -1, self.feature_dim)  # batch x 49 (spatial locations) x 256 (features)
            hidden = self.hidden[0].view(batch, -1, self.feature_dim)  # reshape hidden state
            # attention network
            ht_1 = self.att1(hidden)
            xt_1 = self.att1(y)
            combined_att = ht_1 + xt_1
            combined_att = torch.tanh(combined_att)
            combined_att2 = self.att2(combined_att)
            goutput = F.softmax(combined_att2, dim=2)
            goutput = goutput.view(goutput.size(0), self.feature_dim, -1)
            context = torch.bmm(goutput, y)  # weighted sum over the 49 locations
            context = context.view(-1, batch, self.rnn_input_dim)  # add sequence dimension for the LSTM
            self.hidden = self.repackage_hidden(self.hidden)  # repackage hidden
            self.lstm.flatten_parameters()
            output, self.hidden = self.lstm(context, self.hidden)  # LSTM
            del context
        y = output.view(batch, -1)  # flatten output
        return y
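A minimal sketch of how the module can be exercised and the two passes timed separately (the batch size, the 84x84 frame size, and the dummy loss are assumptions on my part, not my real training loop):

import time
import torch

net = SpatialAttDRQNBody().cuda()
x = torch.randn(32, 4, 84, 84).cuda()  # assumed batch of 4-frame 84x84 stacks

torch.cuda.synchronize()  # make sure the GPU is idle before timing
t0 = time.perf_counter()
out = net(x)
torch.cuda.synchronize()  # wait for the forward kernels to finish
print('forward: %.4f s' % (time.perf_counter() - t0))

t0 = time.perf_counter()
out.sum().backward()      # dummy loss, just to time the backward pass
torch.cuda.synchronize()
print('backward: %.4f s' % (time.perf_counter() - t0))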
Calculating my loss (simplified code):
def step(self):
    if self.total_steps > self.config.exploration_steps:
        experiences = self.replay.sample()
        states, actions, rewards, next_states, terminals = experiences
        states = self.config.state_normalizer(states)
        next_states = self.config.state_normalizer(next_states)
        q_next = self.target_network(next_states).detach()
        if self.config.double_q:
            # double Q-learning: pick actions with the online network,
            # evaluate them with the target network
            best_actions = torch.argmax(self.network(next_states), dim=-1)
            q_next = q_next[self.batch_indices, best_actions]
        else:
            q_next = q_next.max(1)[0]
        terminals = tensor(terminals)
        rewards = tensor(rewards)
        q_next = self.config.discount * q_next * (1 - terminals)  # TD target
        q_next.add_(rewards)
        actions = tensor(actions).long()
        q = self.network(states)
        q = q[self.batch_indices, actions]
        loss = F.smooth_l1_loss(q, q_next)
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.network.parameters(), self.config.gradient_clip)
        with self.config.lock:
            self.optimizer.step()
        del loss
    if self.total_steps / self.config.sgd_update_frequency % \
            self.config.target_network_update_freq == 0:
        self.target_network.load_state_dict(self.network.state_dict())
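To narrow down where the time actually goes, one update could be profiled roughly like this (just a sketch; `agent` is a stand-in for the object whose step() method is shown above, and use_cuda assumes everything is on the GPU):

import torch

# profile a single training step to see which ops dominate
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    agent.step()
print(prof.key_averages().table(sort_by='cuda_time_total'))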