Hey,
being new to PyTorch, I am still a bit uncertain about how to use the built-in loss functions correctly. In the examples on PyTorch's official website, both a so-called input and a target are passed to the loss function. I assume the input tensor models the output of a network, so that the loss function computes the loss as a function of the difference between the target and the input values.
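To make sure I understand that convention, here is a minimal made-up example with nn.MSELoss (the concrete numbers are just mine, a plain tensor standing in for a network output):

```python
import torch
import torch.nn as nn

# "input" is the network's prediction, "target" is the desired value
loss_fn = nn.MSELoss()
prediction = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)  # stands in for a network output
target = torch.tensor([1.0, 2.0, 5.0])

loss = loss_fn(prediction, target)  # mean of the squared differences
loss.backward()                     # gradients flow into `prediction` only, not into `target`
```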
In my concrete case, I am working on an implementation of reinforcement learning (DQN). So, if I am not mistaken and the above is correct, applying a single (simplified) update to my network in each iteration of an update function's for-loop should look something like this:
```python
def updateNetwork(self):
    mini_batch = self.constructMiniBatch()
    for sample in mini_batch:
        # compute q_values (predictions)
        predicted_q_values = self.q_net(sample.current_state, sample.current_state_actions)
        # compute estimated discounted value of the next state
        # (no_grad: no gradients should flow through the target net)
        with t.no_grad():
            discounted_future_rewards = discount_factor * self.target_net(sample.next_state, sample.next_state_actions)
        # max over the next state's q-values
        max_future_reward, _ = t.squeeze(discounted_future_rewards).max(0)
        # sum things up
        max_future_reward += sample.observed_reward
        # first: target values = predictions (detached, so the targets are treated as
        # constants); then only set the particular q_value to the actual target value
        # (max_future_reward) for which we have an observed reward
        target_values = predicted_q_values.clone().detach()
        target_values[0, action_chosen_in_the_past] = max_future_reward
        # apply update - this should make (only) the predicted q_value for which we
        # observed a concrete reward in the past more similar to its target value
        self.q_net.optimizer.zero_grad()
        loss = self.q_net.loss(predicted_q_values, target_values)
        loss.backward()
        self.q_net.optimizer.step()
```
If this is correct so far, I was wondering how to translate this cumbersome sequential variant into a more efficient form of batch training that avoids the outer for-loop. I assume this would lead to better performance due to better parallelization on the GPU.
Could I simply pass a batch of inputs through the network, get a batch of q-values back, and then pass both the batch of predicted q-values and the batch of target values to the loss function? Is it really that easy?
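To illustrate what I have in mind, here is a rough sketch of the target construction done batch-wise. The function name and shapes are my own assumptions: I assume the network returns a [batch_size, n_actions] tensor of q-values, and that rewards and chosen actions come stacked as 1-D tensors.

```python
import torch

def batched_dqn_targets(predicted_q, next_q, rewards, actions, gamma=0.99):
    # bootstrapped targets: r + gamma * max_a' Q_target(s', a'), one per sample
    max_future = rewards + gamma * next_q.max(dim=1).values
    # start from the predictions, detached so no gradient flows into the targets
    targets = predicted_q.clone().detach()
    # overwrite only the q-value of the action actually taken in each sample
    targets[torch.arange(len(actions)), actions] = max_future
    return targets

# usage sketch: loss = loss_fn(predicted_q, batched_dqn_targets(...)); loss.backward()
```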
By the way, I am aware that there is an official tutorial on DQNs on the PyTorch website, but I am still a bit uncertain about it.
Thanks a lot in advance!
Daniel