So I am training a DDQN to play Connect Four at the moment. At each state, the network predicts the best action and moves accordingly. The code looks basically as follows:
    for epoch in range(num_epochs):
        preds, targets = torch.empty(0), torch.empty(0)
        for i in range(batch_size):
            state = initial_state()           # start a new game
            while game is not finished:       # pseudocode
                action = select_action(state)
                new_state = play_move(state, action)
                pred, target = get_pred_target(state, new_state, action)
                preds = torch.cat([preds, pred])
                targets = torch.cat([targets, target])
                state = new_state
        loss = loss_fn(preds, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
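In case it matters, here is roughly what the helpers do. This is a simplified sketch, not my exact code; policy_net, target_net, n_actions, gamma and the dummy reward are just placeholder names here:

    import random
    import torch
    import torch.nn as nn

    n_actions = 7                           # one action per column
    policy_net = nn.Linear(42, n_actions)   # stand-in for my real network
    target_net = nn.Linear(42, n_actions)
    gamma = 0.99

    def select_action(state, epsilon=0.1):
        # epsilon-greedy on the online network; no gradients needed here
        if random.random() < epsilon:
            return random.randrange(n_actions)
        with torch.no_grad():
            return policy_net(state).argmax().item()

    def get_pred_target(state, new_state, action, reward=0.0):
        # prediction: Q-value of the action taken, straight out of policy_net
        pred = policy_net(state)[action].unsqueeze(0)
        # Double-DQN target: online net selects the action, target net evaluates it
        with torch.no_grad():
            best = policy_net(new_state).argmax()
            target = reward + gamma * target_net(new_state)[best]
        return pred, target.unsqueeze(0)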
While training, the network does get a little better, but nowhere near as good as I would expect. Thinking about it, I am now wondering whether I have actually implemented the loss.backward() call correctly. The point is: I am saving all the predictions and targets for each move in the tensors preds and targets. However, I am not tracking the states that led to these predictions and targets. But isn't that necessary for backpropagation, or is this information somehow saved?
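To make the question concrete, here is a toy version of what I mean (nothing to do with my real network):

    import torch
    import torch.nn as nn

    net = nn.Linear(42, 7)
    state = torch.randn(42)

    pred = net(state)[3].unsqueeze(0)         # Q-value of one action
    preds = torch.cat([torch.empty(0), pred])
    print(preds.grad_fn)                      # prints a grad_fn (CatBackward0)

So preds does carry a grad_fn even after torch.cat, but I am not sure whether that actually means the whole path back through the network, and therefore the states, is retained for loss.backward().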
Thank you very much!