I want to change a value in a tensor to a specific value. The current way I try to do it: next_Q[replay_memory.at(trainingIDX).action].item() = reward;
My entire training loop and some output from it can be seen below.
void AiModel::train(Model &model_to_train,
torch::optim::Optimizer &optimizer,
std::vector<AiModel::Experience> &replay_memory) {
model_to_train.train();
//traning loop
for (size_t trainingIDX = 0; trainingIDX < replay_memory.size(); ++trainingIDX) {
float reward;
torch::Tensor next_state;
torch::Tensor next_Q;
torch::Tensor current_state = torch::from_blob(replay_memory.at(trainingIDX).current_state.data(), { k_input_size_ }).to(*device);
if (std::isnan(replay_memory.at(trainingIDX).next_state.at(0))){
//If next state is terminal:
next_state = torch::from_blob(replay_memory.at(trainingIDX).current_state.data(), { k_input_size_ }).to(*device);
next_Q = model_to_train.forward(next_state).detach();
reward = replay_memory.at(trainingIDX).reward;
next_Q[replay_memory.at(trainingIDX).action].item() = reward;
} else {
next_state = torch::from_blob(replay_memory.at(trainingIDX).next_state.data(), { k_input_size_ }).to(*device);
next_Q = model_to_train.forward(next_state).detach();
std::cout << "next_Q: \n" << next_Q << std::endl;
reward = replay_memory.at(trainingIDX).reward + k_gamma_ * next_Q.max().item().toFloat();
reward = 1;
std::cout << "reward: " << reward << std::endl;
next_Q[replay_memory.at(trainingIDX).action].item() = reward;
std::cout << "next_Q after: \n" << next_Q << std::endl;
}
optimizer.zero_grad();
auto current_Q = model_to_train.forward(current_state);
std::cout << "current_Q: \n" << current_Q << std::endl;
auto loss = torch::mse_loss(current_Q, next_Q);
std::cout << "loss: " << loss << std::endl;
loss.backward();
optimizer.step();
}//For traning loop
} // train-function
Which gives me the output:
next_Q:
-0.0558
0.0243
0.1972
-0.1091
0.0342
[ CPUFloatType{5} ]
reward: 1
next_Q after:
-0.0558
0.0243
0.1972
-0.1091
0.0342
[ CPUFloatType{5} ]
current_Q:
-0.0543
0.0240
0.1965
-0.1083
0.0341
[ CPUFloatType{5} ]
loss: 6.94254e-07
[ CPUFloatType{} ]