Where does the learning actually happen in the Reinforcement Learning tutorial?


I will attempt to answer the question posed in the topic.

Looking at the tutorial code, I think this part is crucial:

    # Compute V(s_{t+1}) for all next states.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()

It seems that every entry of next_state_values where non_final_mask is 0 (i.e., where the next state is terminal) is left at 0.0. So when the following line is executed:
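If I've understood the masking correctly, it can be illustrated with a toy batch (made-up numbers, not the tutorial's data; max_q_non_final stands in for target_net(non_final_next_states).max(1)[0]):

```python
import torch

# Hypothetical batch of 4 transitions; transitions 1 and 3 lead to a
# terminal state, so their mask entries are False.
BATCH_SIZE = 4
non_final_mask = torch.tensor([True, False, True, False])

# Stand-in for target_net(non_final_next_states).max(1)[0]:
# one value per NON-terminal next state, in batch order.
max_q_non_final = torch.tensor([1.5, 2.5])

# Terminal positions are never written to, so they stay at 0.0.
next_state_values = torch.zeros(BATCH_SIZE)
next_state_values[non_final_mask] = max_q_non_final

print(next_state_values)  # tensor([1.5000, 0.0000, 2.5000, 0.0000])
```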

    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

the expected_state_action_values for those non_final_mask = 0 transitions reduce to the reward alone, so they come out lower than the values for the remaining transitions, even after the reward has been factored in. The gradient from the subsequent loss will then shift the weights slightly away from actions whose following state is terminal.

Is this how the reinforcement learning works here? Please let me know if I’m understanding the process correctly.

Thank you.