I have a few lines of code that are contributing to a growing memory leak (torch 1.6.0, CUDA 11.0). I don't know how to fix them and would appreciate some feedback.
They are all in one method, agent_update_network_parameters() (shown below). I have marked the offending lines with '# MEMORY LEAK'. There are three of them, and each contributes to the growing memory usage (roughly 1 MB per call).
@profile
def agent_update_network_parameters(self):
    """
    Update the parameters for the NN(s).

    Note: the updates are performed in the following order:
        - both Q-value networks
        - policy network
        - temperature (alpha)
        - target Q-value network (soft update)
    """
    self.num_updates += 1

    state, action, reward, next_state, terminal = self.replay_buffer.sample(self.batch_size)
    state = torch.FloatTensor(state).to(device=self.device)
    action = torch.FloatTensor(action).to(device=self.device)
    reward = torch.FloatTensor(reward).unsqueeze(1).to(device=self.device)
    next_state = torch.FloatTensor(next_state).to(device=self.device)
    # MEMORY LEAK in terminal (originally a numpy array of booleans)
    terminal = torch.FloatTensor(terminal).unsqueeze(1).to(device=self.device)

    # q_value networks
    predicted_q_value_1, predicted_q_value_2 = self.q_network(state, action)
    with torch.no_grad():
        next_state_sampled_action, next_state_log_prob, _ = self.policy_network.sample(next_state)
        predicted_target_q_value_1, predicted_target_q_value_2 = self.target_q_network(next_state, next_state_sampled_action)
        estimated_value = torch.min(predicted_target_q_value_1, predicted_target_q_value_2) - self.alpha * next_state_log_prob
        estimated_q_value = reward + self.gamma * (1 - terminal) * estimated_value
    q_value_loss_1 = self.q_criterion_1(predicted_q_value_1, estimated_q_value)
    q_value_loss_2 = self.q_criterion_2(predicted_q_value_2, estimated_q_value)

    self.q_optimizer_1.zero_grad()
    q_value_loss_1.backward()
    self.q_optimizer_1.step()

    self.q_optimizer_2.zero_grad()
    q_value_loss_2.backward()
    self.q_optimizer_2.step()

    # policy network
    sampled_action, log_prob, _ = self.policy_network.sample(state)
    sampled_q_value_1, sampled_q_value_2 = self.q_network(state, sampled_action)
    sampled_q_value = torch.min(sampled_q_value_1, sampled_q_value_2)
    policy_loss = ((self.alpha * log_prob) - sampled_q_value).mean()

    self.policy_optimizer.zero_grad()
    # MEMORY LEAK in call to policy_loss.backward()
    policy_loss.backward()
    self.policy_optimizer.step()

    # adjust temperature
    if self.automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        self.alpha = self.log_alpha.exp()
    else:
        alpha_loss = torch.tensor(0.).to(self.device)
    # soft update of the target q_value network (polyak averaging)
    if self.num_updates % self.target_update_interval == 0:
        for target_param, param in zip(self.target_q_network.parameters(), self.q_network.parameters()):
            # MEMORY LEAK in the polyak averaging below
            target_param.data.copy_(self.tau * param.data + (1.0 - self.tau) * target_param.data)
        # only delete the loop variables if this branch ran;
        # deleting them unconditionally raises a NameError
        del target_param, param

    index = self.num_updates - 1
    # self.loss_data[index] = [self.num_updates, q_value_loss_1.item(), q_value_loss_2.item(), policy_loss.item(), alpha_loss.item(), self.alpha.item()]

    del state, action, reward, next_state, terminal
    del predicted_q_value_1, predicted_q_value_2
    del next_state_sampled_action, next_state_log_prob
    del predicted_target_q_value_1, predicted_target_q_value_2
    del estimated_value, estimated_q_value
    del q_value_loss_1, q_value_loss_2
    del sampled_action, log_prob
    del sampled_q_value_1, sampled_q_value_2
    del sampled_q_value, policy_loss
    del alpha_loss
    del index
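To confirm that the growth happens on the GPU rather than in host memory, I log the allocated CUDA memory around each call. This is only a minimal sketch: "agent" stands for my already-constructed agent (with a replay buffer holding at least batch_size transitions), and the loop count is arbitrary.

import torch

# Sketch: track per-call CUDA memory growth.
# `agent` is assumed to be a constructed agent whose replay buffer
# already holds at least batch_size transitions.
for i in range(10):
    before = torch.cuda.memory_allocated()  # bytes held by live tensors
    agent.agent_update_network_parameters()
    after = torch.cuda.memory_allocated()
    print(f"call {i}: allocated memory grew by {(after - before) / 1024 ** 2:.3f} MB")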
I am also providing the code for my policy network, since its sample() method is called in the update method above.
import torch
import torch.nn as nn
from torch.distributions import Normal

class GaussianPolicyNetwork(nn.Module):
    """
    Agent policy.
    """
    LOG_SIG_MAX = 2
    LOG_SIG_MIN = -20
    epsilon = 1e-6

    def __init__(self, state_dim, action_dim, hidden_dim):
        """
        Initialize policy network.

        @param state_dim: int
            environment state dimension
        @param action_dim: int
            action dimension
        @param hidden_dim: int
            hidden layer dimension
        """
        super(GaussianPolicyNetwork, self).__init__()
        self.linear1 = nn.Linear(state_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean_linear = nn.Linear(hidden_dim, action_dim)
        self.log_std_linear = nn.Linear(hidden_dim, action_dim)
        self.apply(init_weights)  # init_weights is defined elsewhere in my code
    def forward(self, state):
        """
        Calculate the mean and log standard deviation of the policy distribution.

        @param state: torch.float32 tensor with shape torch.Size([1, state_dim]) or torch.Size([batch_size, state_dim])
            state of the environment
        @return (mean, log_std):
            mean: torch.float32 tensor with shape torch.Size([1, action_dim]) or torch.Size([batch_size, action_dim])
                mean of the policy distribution
            log_std: torch.float32 tensor with shape torch.Size([1, action_dim]) or torch.Size([batch_size, action_dim])
                log standard deviation of the policy distribution
        """
        x = torch.relu(self.linear1(state))
        x = torch.relu(self.linear2(x))
        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, min=self.LOG_SIG_MIN, max=self.LOG_SIG_MAX)
        return mean, log_std
    def sample(self, state):
        """
        Sample an action using the reparameterization trick:
            - sample noise from a standard normal distribution,
            - multiply it by the standard deviation of the policy distribution,
            - add it to the mean of the policy distribution, and
            - apply the tanh function to the result.

        @param state: torch.float32 tensor with shape torch.Size([1, state_dim]) or torch.Size([batch_size, state_dim])
            state of the environment
        @return (action, log_prob, mean):
            action: torch.float32 tensor with shape torch.Size([1, action_dim]) or torch.Size([batch_size, action_dim])
                (normalized) action selected by the agent
            log_prob: torch.float32 tensor with shape torch.Size([1, 1]) or torch.Size([batch_size, 1])
                log probability of the action
            mean: torch.float32 tensor with shape torch.Size([1, action_dim]) or torch.Size([batch_size, action_dim])
                mean of the policy distribution
        """
        mean, log_std = self.forward(state)  # torch.Size([batch_size, action_dim])
        std = log_std.exp()                  # torch.Size([batch_size, action_dim])
        normal = Normal(mean, std)
        z = normal.rsample()                 # reparameterized sample: mean + std * noise
        action = torch.tanh(z)
        # change-of-variables correction for the tanh squashing:
        # log pi(a|s) = log N(z) - log(1 - tanh(z)^2 + epsilon)
        log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + self.epsilon)
        log_prob = log_prob.sum(1, keepdim=True)
        return action, log_prob, mean
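For completeness, here is how sample() can be exercised in isolation to check the returned shapes. The dimensions are made up for illustration, and init_weights here is just a stand-in for my real initializer:

import torch
import torch.nn as nn

# Stand-in for the real init_weights used above (illustration only).
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.constant_(m.bias, 0.0)

policy = GaussianPolicyNetwork(state_dim=8, action_dim=2, hidden_dim=64)
state = torch.randn(32, 8)  # batch of 32 hypothetical states
action, log_prob, mean = policy.sample(state)
print(action.shape, log_prob.shape, mean.shape)
# torch.Size([32, 2]) torch.Size([32, 1]) torch.Size([32, 2])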