This model has two output heads that share a common trunk.
What is the proper way to write a multi headed model in C++?
Does the code below suffice, or do I need to use some special module from the torch::nn library?
std::tuple<torch::Tensor, torch::Tensor> Critic_Net : torch::nn::Module {
torch::Tensor next_state_batch__sampled_action;
// This network returns just one value
public:
Critic_Net() {
// Construct and register two Linear submodules.
lin1 = register_module("lin1", torch::nn::Linear(427, 42));
lin2 = register_module("lin2", torch::nn::Linear(42, 286));
head_1 = register_module("head_1", torch::nn::Linear(286, 1));
head_2 = register_module("head_2", torch::nn::Linear(286, 1));
lin1->to(device);
lin2->to(device);
head_1->to(device);
head_2->to(device);
}
torch::Tensor forward(torch::Tensor next_state_batch__sampled_action) {
auto h = next_state_batch__sampled_action;
h = torch::relu(lin1->forward(h));
h = torch::tanh(lin2->forward(h));
h1 = head_1->forward(h);
h2 = head_2->forward(h);
return {h1, h2};
}
torch::nn::Linear lin1{nullptr}, lin2{nullptr}, head_1{nullptr}, head_2{nullptr};
};