Get argmax as an int?

Hello!
I'm having trouble getting the predicted value as an int. I am implementing a DQN algorithm with the following model:

/// Fully connected network: k_input_size -> 64 -> 64 -> k_output_size.
/// k_input_size and k_output_size are constants defined elsewhere.
struct Model : torch::nn::Module {

    torch::nn::Linear input_layer{ nullptr }, first_hidden_layer{ nullptr }, output_layer{ nullptr };

    /// Creates the three linear layers and registers each one with this
    /// module under the given name.
    Model() {
        input_layer = register_module("input_layer", torch::nn::Linear(k_input_size, 64));
        first_hidden_layer = register_module("first_hidden_layer", torch::nn::Linear(64, 64));
        output_layer = register_module("output_layer", torch::nn::Linear(64, k_output_size));
    }

    /// Runs a forward pass through the network.
    /// @param tensor_x Input batch; presumably shaped {batch, k_input_size},
    ///        since Linear consumes the last dimension and log_softmax below
    ///        normalizes over dim 1 — confirm with callers.
    /// @return Log-probabilities over the k_output_size outputs along dim 1.
    torch::Tensor forward(torch::Tensor tensor_x) {
        tensor_x = torch::relu(input_layer->forward(tensor_x));
        tensor_x = torch::relu(first_hidden_layer->forward(tensor_x));
        // No ReLU on the output layer: clamping logits at 0 discards their
        // sign, so all negative logits would collapse to the same value and
        // argmax over the result could no longer distinguish them.
        tensor_x = output_layer->forward(tensor_x);

        // NOTE(review): a DQN usually returns the raw Q-values from
        // output_layer. log_softmax leaves each row's argmax unchanged but
        // rescales the values a Q-learning loss would use — confirm intent.
        return torch::log_softmax(tensor_x, /*dim=*/1);
    }
};

I then pass it my input vector and want it to return the index of the output with the highest value. I tried getting this with argmax, but argmax returns a tensor, not an int.

int getAction(std::vector &input_vector) {
c10::IntArrayRef size({ k_input_size,1 });
torch::Tensor input_tensor = torch::from_blob(input_vector.data(), size);

auto DNN_out = model.forward(input_tensor);
int action = DNN_out.argmax(1); //Assuming dim 1 is the one I want here?

return action;
}

This works only if DNN_out.argmax(1) is a tensor with exactly one element (i.e., if DNN_out has a shape like {1, n}), because item() can only extract a value from a single-element tensor:

int action = DNN_out.argmax(1).item().toInt();
1 Like

Thank you very much!