Hi guys, after reading up on autograd I still cannot explain how it works in detail.
In the example (Reinforcement Learning (DQN) Tutorial — PyTorch Tutorials 2.1.1+cu121 documentation) you can see that two tensors go into the loss function:
def optimize_model():
    # [parts removed]
    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)
    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
    # Compute Huber loss
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))
    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()
state_action_values is created with gather, resulting in grad_fn=SqueezeBackward1, which keeps track of the operations that produced it.
next_state_values is detached with detach(); expected_state_action_values is not.
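To make the observation concrete, here is a minimal sketch of how I inspect which tensors autograd is tracking. The two nn.Linear layers are hypothetical stand-ins for policy_net and target_net, the batch tensors are made up, and describe() is a helper I introduced just for this check.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the tutorial's networks: 4 state features, 2 actions.
policy_net = nn.Linear(4, 2)
target_net = nn.Linear(4, 2)

# Made-up batch data, only for illustration.
state_batch = torch.randn(3, 4)
next_state_batch = torch.randn(3, 4)
action_batch = torch.tensor([[0], [1], [0]])  # shape (batch, 1), dtype int64
reward_batch = torch.ones(3)
GAMMA = 0.99

# Built from policy_net's parameters, so autograd records a graph for it.
state_action_values = policy_net(state_batch).gather(1, action_batch)

# detach() returns a tensor that shares data but is cut off from the graph.
next_state_values = target_net(next_state_batch).max(1)[0].detach()

# Derived only from untracked tensors, so it is untracked as well.
expected_state_action_values = (next_state_values * GAMMA) + reward_batch


def describe(t):
    """Return (requires_grad, class name of grad_fn or None) for a tensor."""
    name = type(t.grad_fn).__name__ if t.grad_fn is not None else None
    return t.requires_grad, name


tracked = describe(state_action_values)
detached = describe(next_state_values)
expected = describe(expected_state_action_values)
print(tracked, detached, expected)
```

In this sketch only state_action_values reports requires_grad=True and a grad_fn; the other two report neither, because everything they are computed from was already cut off from the graph.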
(Q1) Which tensors have to be included in the autograd graph, which do not, and why?
(Q2) Does it make sense to detach() expected_state_action_values? Why or why not?
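One way I can think of to test Q2 is to run the backward pass twice, once with the target detached and once without, and compare the results. This is only a sketch under assumptions: the nn.Linear layers are hypothetical stand-ins for policy_net and target_net, the data is random, run_backward() is a made-up helper, and I assume a recent PyTorch version in which smooth_l1_loss accepts a target that requires grad.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for the two networks (4 state features, 2 actions).
policy_net = nn.Linear(4, 2)
target_net = nn.Linear(4, 2)

# Made-up batch data, only for illustration.
state_batch = torch.randn(5, 4)
next_state_batch = torch.randn(5, 4)
action_batch = torch.randint(0, 2, (5, 1))
reward_batch = torch.randn(5)
GAMMA = 0.99


def run_backward(detach_target):
    """One backward pass; returns policy_net's weight grad and whether target_net got a grad."""
    # Reset gradients explicitly so both runs start from the same state.
    for p in list(policy_net.parameters()) + list(target_net.parameters()):
        p.grad = None
    q = policy_net(state_batch).gather(1, action_batch)
    next_v = target_net(next_state_batch).max(1)[0]
    if detach_target:
        next_v = next_v.detach()
    expected = next_v * GAMMA + reward_batch
    loss = F.smooth_l1_loss(q, expected.unsqueeze(1))
    loss.backward()
    return policy_net.weight.grad.clone(), target_net.weight.grad is not None


grad_detached, target_got_grad_detached = run_backward(detach_target=True)
grad_attached, target_got_grad_attached = run_backward(detach_target=False)

same_policy_grad = torch.allclose(grad_detached, grad_attached)
print(same_policy_grad, target_got_grad_detached, target_got_grad_attached)
```

In this setup the gradient of policy_net comes out the same in both runs, since the target only enters the loss as a value. The difference is that without detach() backward also walks through target_net and fills its .grad fields, which the optimizer never uses here because it only holds policy_net's parameters.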
And another example for Q3:
q_values = self.DQN.predict(self.c_states).gather(1, self.actions.unsqueeze(1)).squeeze(1)
dqn_next = self.DQN.predict(self.n_states)
q_action = torch.argmax(dqn_next, dim=1)
tar_next = self.TAR.predict(self.n_states).detach().gather(1, q_action.unsqueeze(1)).squeeze(1)
q_target = self.rewards + (self.GAMMA * tar_next * self.dones)
loss = self.DQN.fit(current_q=q_values, expected_q=q_target)
def fit(self, current_q, expected_q):
    loss = self.loss_func(current_q, expected_q)
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    return loss.item()
(Q3) Can I detach() tar_next or q_target and then call fit(q_values, q_target), or do I need to remove the detach() here because the information coming from gather (on self.TAR.predict) is needed?
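For Q3, this is a small sketch of the comparison I have in mind: detach() placed before gather (as in my code) versus detach() applied to the finished q_target. The nn.Linear layers stand in for self.DQN and self.TAR, MSELoss is an assumption in place of my actual loss_func, the data is random, and policy_grad() is a hypothetical helper.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for self.DQN and self.TAR (4 state features, 3 actions).
dqn = nn.Linear(4, 3)
tar = nn.Linear(4, 3)

# Made-up batch data, only for illustration.
c_states = torch.randn(6, 4)
n_states = torch.randn(6, 4)
actions = torch.randint(0, 3, (6,))
rewards = torch.randn(6)
dones = torch.tensor([1., 1., 0., 1., 1., 0.])  # used as a multiplier, as in my code
GAMMA = 0.99
loss_func = nn.MSELoss()  # assumption: stands in for self.loss_func


def policy_grad(where):
    """Backward pass with detach() placed 'before' gather or on the final 'target'."""
    for p in list(dqn.parameters()) + list(tar.parameters()):
        p.grad = None
    q_values = dqn(c_states).gather(1, actions.unsqueeze(1)).squeeze(1)
    # argmax returns integer indices, which carry no gradient.
    q_action = torch.argmax(dqn(n_states), dim=1)
    tar_out = tar(n_states)
    if where == "before":
        tar_next = tar_out.detach().gather(1, q_action.unsqueeze(1)).squeeze(1)
        q_target = rewards + GAMMA * tar_next * dones
    else:
        tar_next = tar_out.gather(1, q_action.unsqueeze(1)).squeeze(1)
        q_target = (rewards + GAMMA * tar_next * dones).detach()
    loss_func(q_values, q_target).backward()
    return dqn.weight.grad.clone(), q_target.requires_grad, tar.weight.grad is None


grad_before, target_tracked_before, tar_untouched_before = policy_grad("before")
grad_after, target_tracked_after, tar_untouched_after = policy_grad("target")

same_grad = torch.allclose(grad_before, grad_after)
print(same_grad, target_tracked_before, target_tracked_after)
```

With these stand-ins both placements give the same gradient for dqn, q_target is untracked either way, and tar receives no gradient. gather after detach() still selects the same values; it just does so on a tensor that is no longer part of the graph.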
Many thanks!