So I am training a DDQN to play Connect Four at the moment. At each state, the network predicts the best action and moves accordingly. The code looks basically as follows:
    for epoch in range(num_epochs):
        preds, targets = torch.empty(0), torch.empty(0)
        for i in range(batch_size):
            state = initial_state()           # start a new game
            while game is not finished:       # pseudocode
                action = select_action(state)
                new_state = play_move(state, action)
                pred, target = get_pred_target(state, new_state, action)
                preds = torch.cat([preds, pred])
                targets = torch.cat([targets, target])
                state = new_state
        loss = loss_fn(preds, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
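In case it matters, here is roughly what the helpers do. This is a simplified sketch, not my exact code; policy_net, target_net, n_actions, gamma and the dummy reward are just placeholder names here:

    import random
    import torch
    import torch.nn as nn

    n_actions = 7                           # one action per column
    policy_net = nn.Linear(42, n_actions)   # stand-in for my real network
    target_net = nn.Linear(42, n_actions)
    gamma = 0.99

    def select_action(state, epsilon=0.1):
        # epsilon-greedy on the online network; no gradients needed here
        if random.random() < epsilon:
            return random.randrange(n_actions)
        with torch.no_grad():
            return policy_net(state).argmax().item()

    def get_pred_target(state, new_state, action, reward=0.0):
        # prediction: Q-value of the action taken, straight out of policy_net
        pred = policy_net(state)[action].unsqueeze(0)
        # Double-DQN target: online net selects the action, target net evaluates it
        with torch.no_grad():
            best = policy_net(new_state).argmax()
            target = reward + gamma * target_net(new_state)[best]
        return pred, target.unsqueeze(0)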
While training, the network does get a little better, but nowhere near as good as I would expect. Thinking about it, I am now wondering whether I have actually implemented the loss.backward() call correctly. The point is: I am saving all the predictions and targets for each move in the tensors preds and targets. However, I am not tracking the states that led to these predictions and targets. But isn't that necessary for backpropagation, or is this information somehow saved?
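To make the question concrete, here is a toy version of what I mean (nothing to do with my real network):

    import torch
    import torch.nn as nn

    net = nn.Linear(42, 7)
    state = torch.randn(42)

    pred = net(state)[3].unsqueeze(0)         # Q-value of one action
    preds = torch.cat([torch.empty(0), pred])
    print(preds.grad_fn)                      # prints a grad_fn (CatBackward0)

So preds does carry a grad_fn even after torch.cat, but I am not sure whether that actually means the whole path back through the network, and therefore the states, is retained for loss.backward().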
Thank you very much!