Slow backward passes in RNN code

Hi,
I’m new to PyTorch. My forward and backward passes are very slow. Are there any obvious optimizations I could make to this code that I’m not aware of?

import torch
import torch.nn as nn
import torch.nn.functional as F


class SpatialAttDRQNBody(nn.Module):
    def __init__(self, in_channels=4):
        super(SpatialAttDRQNBody, self).__init__()
        self.feature_dim = 256
        self.rnn_input_dim = 256
        self.batch_size = 1
        self.unroll = 4
        in_channels = 1  # each timestep of the unroll sees a single frame (input is chunked in forward)
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4, bias=False)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, bias=False)
        self.conv3 = nn.Conv2d(64, 256, kernel_size=3, stride=1, bias=False)
        self.att1 = nn.Linear(self.feature_dim, self.feature_dim, bias=False)
        self.att2 = nn.Linear(self.feature_dim, self.feature_dim, bias=False)
        self.lstm = nn.LSTM(self.rnn_input_dim, self.feature_dim, num_layers=1)

        self.hidden = self.init_hidden()

    def init_hidden(self, num_layers=1, batch=1):
        # initialize the hidden and cell states (Variable is deprecated; plain tensors suffice)
        return (torch.zeros(num_layers, batch, self.feature_dim).cuda(),
                torch.zeros(num_layers, batch, self.feature_dim).cuda())

    def repackage_hidden(self, h):
        if isinstance(h, torch.Tensor):
            return h.detach()
        else:
            return tuple(self.repackage_hidden(state) for state in h)

    def forward(self, x):
        batch = x.size(0)
        output = torch.Tensor()

        xchunks = torch.chunk(x, self.unroll, 1)  # one single-channel chunk per timestep
        self.hidden = self.init_hidden(batch=batch)

        for ts in range(len(xchunks)):
            y = F.relu(self.conv1(xchunks[ts]))
            y = F.relu(self.conv2(y))
            y = F.relu(self.conv3(y))
            y = y.view(batch, -1, self.feature_dim)  # batch x 49 (spatial positions) x 256 (feature dim)
            hidden = self.hidden[0].view(batch, -1, self.feature_dim)  # reshape hidden state

            # Attention network
            ht_1 = self.att1(hidden)
            xt_1 = self.att1(y)
            combined_att = ht_1 + xt_1
            combined_att = torch.tanh(combined_att)
            combined_att2 = self.att2(combined_att)
            goutput = F.softmax(combined_att2, dim=2)
            goutput = goutput.view(goutput.size(0), self.feature_dim, -1)

            context = torch.bmm(goutput, y)  # apply attention weights (batched matmul)
            context = context.view(-1, batch, self.rnn_input_dim)  # reshape to (seq_len, batch, input_dim) for the LSTM
            self.hidden = self.repackage_hidden(self.hidden)  # detach hidden from the previous chunk's graph
            self.lstm.flatten_parameters()
            output, self.hidden = self.lstm(context, self.hidden)  # LSTM step
            del context

        y = output.view(batch, -1)  # flatten LSTM output
        return y
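
A minimal way to reproduce the timing on its own (sketch; assuming 84x84 input frames, which matches the 7x7 = 49 spatial positions in the comment above):

net = SpatialAttDRQNBody().cuda()
dummy = torch.randn(1, 4, 84, 84).cuda()  # one sample, 4 stacked frames
out = net(dummy)                          # forward pass
out.sum().backward()                      # backward pass (this is the slow part)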

Calculating my loss (simplified code):

def step(self):
    if self.total_steps > self.config.exploration_steps:
        experiences = self.replay.sample()
        states, actions, rewards, next_states, terminals = experiences
        states = self.config.state_normalizer(states)
        next_states = self.config.state_normalizer(next_states)
        q_next = self.target_network(next_states).detach()
        if self.config.double_q:
            best_actions = torch.argmax(self.network(next_states), dim=-1)
            q_next = q_next[self.batch_indices, best_actions]
        else:
            q_next = q_next.max(1)[0]
        terminals = tensor(terminals)
        rewards = tensor(rewards)
        q_next = self.config.discount * q_next * (1 - terminals)
        q_next.add_(rewards)
        actions = tensor(actions).long()
        q = self.network(states)
        q = q[self.batch_indices, actions]
        
        loss = F.smooth_l1_loss(q, q_next)
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.network.parameters(), self.config.gradient_clip)
        with self.config.lock:
            self.optimizer.step()
            del loss
   
    if self.total_steps // self.config.sgd_update_frequency % \
            self.config.target_network_update_freq == 0:
        self.target_network.load_state_dict(self.network.state_dict())
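
In case it helps, this is roughly how I'm trying to see where the time goes in a single update (sketch; 'agent' is a placeholder name for the object that owns step() above, and the profiler usage is just my attempt):

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    agent.step()  # one optimization step as defined above
print(prof.key_averages().table(sort_by='cuda_time_total'))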