PyTorch Memory Leak (Help!)

I have a few lines of code that are contributing to a growing memory leak (torch 1.6.0 and CUDA 11.0). I don't know how to fix them and would appreciate some feedback.
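
(In case it is relevant: a quick way to check whether the growth also shows up on the GPU is to log the allocator stats between updates. Below is just a sketch, with agent standing in for my agent object; torch.cuda.memory_allocated() only counts memory held by live PyTorch tensors.)

    import torch

    # Log allocator stats every few updates; if "allocated" keeps climbing,
    # tensors are being kept alive across calls somewhere.
    for i in range(100):
        agent.agent_update_network_parameters()
        if i % 10 == 0:
            allocated = torch.cuda.memory_allocated() / 1024 ** 2  # MB held by live tensors
            reserved = torch.cuda.memory_reserved() / 1024 ** 2    # MB reserved by the caching allocator
            print(f"update {i}: allocated={allocated:.1f} MB, reserved={reserved:.1f} MB")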

They are all in one method, agent_update_network_parameters() (shown below). I have marked the leaking lines with '# MEMORY LEAK'. There are three of them, each contributing to the growing memory usage (about 1 MB per call).

    @profile
    def agent_update_network_parameters(self):
        """
        Update the parameters for the NN(s).

        Note: This is performed in the following order:
        - value network
        - both q-value networks
        - policy network
        """

        self.num_updates += 1

        state, action, reward, next_state, terminal = self.replay_buffer.sample(self.batch_size)

        state = torch.FloatTensor(state).to(device=self.device)
        action = torch.FloatTensor(action).to(device=self.device)
        reward = torch.FloatTensor(reward).unsqueeze(1).to(device=self.device)
        next_state = torch.FloatTensor(next_state).to(device=self.device)
        # MEMORY LEAK in terminal (originally was numpy array of boolean values)
        terminal = torch.FloatTensor(terminal).unsqueeze(1).to(device=self.device)

        # q_value network

        predicted_q_value_1, predicted_q_value_2 = self.q_network(state, action)

        with torch.no_grad():

            next_state_sampled_action, next_state_log_prob, _ = self.policy_network.sample(next_state)

            predicted_target_q_value_1, predicted_target_q_value_2 = self.target_q_network(next_state, next_state_sampled_action)
            estimated_value = torch.min(predicted_target_q_value_1, predicted_target_q_value_2) - self.alpha * next_state_log_prob
            estimated_q_value = reward + self.gamma * (1 - terminal) * estimated_value

        q_value_loss_1 = self.q_criterion_1(predicted_q_value_1, estimated_q_value)
        q_value_loss_2 = self.q_criterion_2(predicted_q_value_2, estimated_q_value)

        self.q_optimizer_1.zero_grad()
        q_value_loss_1.backward()
        self.q_optimizer_1.step()

        self.q_optimizer_2.zero_grad()
        q_value_loss_2.backward()
        self.q_optimizer_2.step()

        # policy network
        sampled_action, log_prob, _ = self.policy_network.sample(state)

        sampled_q_value_1, sampled_q_value_2 = self.q_network(state, sampled_action)
        sampled_q_value = torch.min(sampled_q_value_1, sampled_q_value_2)
        policy_loss = ((self.alpha * log_prob) - sampled_q_value).mean()

        self.policy_optimizer.zero_grad()
        # MEMORY LEAK in call to policy_loss.backward
        policy_loss.backward()
        self.policy_optimizer.step()

        # adjust temperature
        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()

            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()

            self.alpha = self.log_alpha.exp()
        else:
            alpha_loss = torch.tensor(0.).to(self.device)

        # (soft update) target q_value network
        if self.num_updates % self.target_update_interval == 0:
            for target_param, param in zip(self.target_q_network.parameters(), self.q_network.parameters()):
                # MEMORY LEAK in polyak averaging below
                target_param.data.copy_(self.tau * param.data + (1.0 - self.tau) * target_param.data)

        index = self.num_updates - 1
        # self.loss_data[index] = [self.num_updates, q_value_loss_1.item(), q_value_loss_2.item(), policy_loss.item(), alpha_loss.item(), self.alpha.item()]

        del state, action, reward, next_state, terminal
        del predicted_q_value_1, predicted_q_value_2
        del next_state_sampled_action, next_state_log_prob
        del predicted_target_q_value_1, predicted_target_q_value_2
        del estimated_value, estimated_q_value
        del q_value_loss_1, q_value_loss_2
        del sampled_action, log_prob
        del sampled_q_value_1, sampled_q_value_2
        del sampled_q_value, policy_loss
        del alpha_loss
        del target_param, param
        del index
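
For the soft update marked above (the Polyak averaging), an in-place version would at least avoid allocating new tensors for every parameter on each call. This is just a sketch written as a standalone helper (soft_update is a name I made up here):

    import torch

    def soft_update(target_network, source_network, tau):
        # In-place Polyak update: target <- tau * source + (1 - tau) * target.
        # mul_/add_ modify the target parameters directly, so no intermediate
        # tensor is created for the weighted sum on each call.
        with torch.no_grad():
            for target_param, param in zip(target_network.parameters(),
                                           source_network.parameters()):
                target_param.data.mul_(1.0 - tau)
                target_param.data.add_(param.data, alpha=tau)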

I am also providing code for my policy network.

class GaussianPolicyNetwork(nn.Module):
    """
    Agent policy.
    """

    LOG_SIG_MAX = 2
    LOG_SIG_MIN = -20
    epsilon = 1e-6

    def __init__(self, state_dim, action_dim, hidden_dim):
        """
        Initialize policy network.

        @param state_dim: int
            environment state dimension
        @param action_dim: int
            action dimension
        @param hidden_dim: int
            hidden layer dimension
        """

        super(GaussianPolicyNetwork, self).__init__()

        self.linear1 = nn.Linear(state_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)

        self.mean_linear = nn.Linear(hidden_dim, action_dim)
        self.log_std_linear = nn.Linear(hidden_dim, action_dim)

        self.apply(init_weights)

    def forward(self, state):
        """
        Calculate the mean and log standard deviation of the policy distribution.

        @param state: torch.float32 tensor with shape torch.Size([1, state_dim]) or torch.Size([batch_size, state_dim])
            state of the environment

        @return (mean, log_std):
        mean: torch.float32 tensor with shape torch.Size([1, action_dim]) or torch.Size([batch_size, action_dim])
            mean of the policy distribution
        log_std: torch.float32 tensor with shape torch.Size([1, action_dim]) or torch.Size([batch_size, action_dim])
            log standard deviation of the policy distribution
        """

        x = torch.relu(self.linear1(state))
        x = torch.relu(self.linear2(x))

        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, min=self.LOG_SIG_MIN, max=self.LOG_SIG_MAX)

        return mean, log_std

    def sample(self, state):
        """
        Sample an action using the reparameterization trick:
        - sample noise from a normal distribution,
        - multiply it with the standard deviation of the policy distribution,
        - add it to the mean of the policy distribution, and
        - apply the tanh function to the result.

        @param state: torch.float32 tensor with shape torch.Size([1, state_dim]) or torch.Size([batch_size, state_dim])
            state of the environment

        @return action, log_prob, mean:
        action: torch.float32 tensor with shape torch.Size([1, action_dim]) or torch.Size([batch_size, action_dim])
            (normalized) action selected by the agent
        log_prob: torch.float32 tensor with shape torch.Size([1, 1]) or torch.Size([batch_size, 1])
            log probability of the action
        mean: torch.float32 tensor with shape torch.Size([1, action_dim]) or torch.Size([batch_size, action_dim])
            mean of the policy distribution
        """

        mean, log_std = self.forward(state)  # torch.float32 torch.Size([batch_size, action_dim])
        std = log_std.exp()  # torch.float32 torch.Size([batch_size, action_dim])

        normal = Normal(mean, std)
        z = normal.rsample()
        action = torch.tanh(z)

        log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + self.epsilon)
        log_prob = log_prob.sum(1, keepdim=True)

        return action, log_prob, mean

I had a memory leak in a similar situation (to your first instance), creating SARSD tensors from NumPy arrays for RL. The fix in my case was changing the dtype in the tensor constructor. I can't offer a principled explanation, but you could try changing the constructor for terminal to:

    torch.as_tensor(terminal, dtype=torch.bool).unsqueeze(1).to(device=self.device)

and estimated_q_value can then be computed as:

    estimated_q_value = reward + self.gamma * (~terminal) * estimated_value

A boolean dtype should be preferred in this case, right?
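
For reference, here is a quick self-contained check that the bool version behaves as intended (the batch size and values below are made up):

    import numpy as np
    import torch

    terminal_np = np.array([False, True, False, False])                     # as stored in the replay buffer
    terminal = torch.as_tensor(terminal_np, dtype=torch.bool).unsqueeze(1)  # bool, shape (4, 1)

    reward = torch.randn(4, 1)
    estimated_value = torch.randn(4, 1)
    gamma = 0.99

    # ~terminal flips the booleans; multiplying a float tensor by a bool tensor
    # promotes the result to float, so terminal transitions get no bootstrapped value.
    estimated_q_value = reward + gamma * (~terminal) * estimated_value
    print(estimated_q_value.dtype, estimated_q_value.shape)  # torch.float32 torch.Size([4, 1])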