I am implementing DQN with images as states… and for some reason the RAM usage keeps increasing despite a fixed replay memory size, even after I delete the variables. Can anyone tell me what's wrong? Thanks a lot!
def learn(self):
    """Sample one minibatch from replay memory and do a single DQN update.

    Skips the update until the memory holds at least ``self.batch_size``
    transitions. Every ``self.iter_update_target`` iterations the target
    network is synced with the online (evaluate) network.
    """
    self.iter_counter += 1
    if len(self.memory) < self.batch_size:
        return

    # Random transition batch is taken from experience replay memory.
    # Each transition is (state, action, reward, next_state, done).
    transitions = self.memory.sample(self.batch_size)
    states, actions, rewards, next_states, dones = zip(*transitions)

    # Stack the per-sample image tensors into (B, ...) batches and move
    # them to the GPU. NOTE: `non_blocking=True` replaces the pre-0.4
    # `async=True` keyword argument -- `async` is a reserved word since
    # Python 3.7 and is a SyntaxError there. The deprecated Variable
    # wrapper is dropped too (plain tensors track grad since PyTorch 0.4).
    batch_state = torch.stack(
        [transform_img_for_model(s) for s in states]
    ).cuda(non_blocking=True)
    batch_next_state = torch.stack(
        [transform_img_for_model(s) for s in next_states]
    ).cuda(non_blocking=True)

    # Column vectors of shape (B, 1). Actions index into the Q output via
    # gather(), so build them as LongTensor directly instead of a
    # FloatTensor that gets cast with .long() at every use.
    batch_action = torch.LongTensor(actions).unsqueeze(1).cuda(non_blocking=True)
    batch_reward = torch.FloatTensor(rewards).unsqueeze(1).cuda(non_blocking=True)
    batch_done = torch.FloatTensor(
        [float(d) for d in dones]
    ).unsqueeze(1).cuda(non_blocking=True)

    # Current Q values: Q(s, a) for the actions actually taken, shape (B, 1).
    current_q_values = self.evaluate_net(batch_state).gather(1, batch_action)

    # Bootstrap target: max_a' Q_target(s', a'), detached so no gradient
    # flows into the target network (and no autograd graph is retained).
    max_next_q_values = (
        self.target_net(batch_next_state).detach().max(1)[0].unsqueeze(1)
    )
    # BUGFIX: the original collected the done flags but never used them, so
    # terminal transitions still bootstrapped from the next state. Mask the
    # bootstrap term with (1 - done), the standard DQN target.
    expected_q_values = batch_reward + self.gamma * max_next_q_values * (1.0 - batch_done)

    # Loss between current and target Q values, then one optimizer step.
    loss = self.loss_function(current_q_values, expected_q_values)
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    # No manual `del` is needed: every local here is released when the
    # method returns, and nothing above stores a reference to the graph.
    # NOTE(review): if RAM still grows, the usual culprit is the replay
    # memory itself holding autograd-tracked or CUDA tensors -- store plain
    # CPU/numpy data in self.memory and convert here. TODO confirm.

    # Periodically sync the target network with the online network.
    if self.iter_counter % self.iter_update_target == 0:
        self.target_net.load_state_dict(self.evaluate_net.state_dict())