Defining a network in C++ and copying its state_dict

I have some basic questions about defining a neural network in C++ (LibTorch).
What is the difference between

struct DQN : torch::nn::Module
{
    // Submodule handles; each one is registered with this module in the constructor
    // under the names "layer1", "layer2" and "layer3".
    torch::nn::Linear layer1;
    torch::nn::Linear layer2;
    torch::nn::Linear layer3;

    // Builds the stack in -> hidden -> hidden -> out.
    DQN(int in, int hidden, int out)
        : layer1(register_module("layer1", torch::nn::Linear(in, hidden))),
          layer2(register_module("layer2", torch::nn::Linear(hidden, hidden))),
          layer3(register_module("layer3", torch::nn::Linear(hidden, out)))
    {
    }

    // ReLU after the first two layers; the last layer has no activation.
    torch::Tensor forward(torch::Tensor x)
    {
        return layer3->forward(
            torch::relu(layer2->forward(
                torch::relu(layer1->forward(x)))));
    }
};

and

struct DQN : torch::nn::Module
{
    torch::nn::Sequential fc_val;
      
    DQN(int in, int hidden, int out):  {
      
        fc_val = torch::nn::Sequential(
                                        torch::nn::Linear(in,hidden),
                                        torch::nn::ReLU(),
                                        torch::nn::Linear(hidden,hidden),
                                        torch::nn::ReLU(),
                                        torch::nn::Linear(hidden,hidden),
                                        torch::nn::ReLU(),
                                        torch::nn::Linear(hidden,out));
    }
  

    torch::Tensor forward (torch::Tensor x){
         x=fc_val->forward(x);
         return x;
    }

};

Also, when defining a policy net and a target net for Deep Q-Learning, I need to copy the parameters from one net to the other.
In python:
target_net.load_state_dict(policy_net.state_dict())

Since these functions are not available in the C++ API, I am trying to use

DQN policy_net(5,100,3);
DQN target_net(5,100,3);
std::stringstream stream;
torch::save(policy_net, stream);
torch::load(target_net, stream);

as suggested by https://github.com/pytorch/pytorch/issues/36577
But since policy_net and target_net are plain structs of type DQN (not module holders or shared pointers), there is no matching << operator, so this does not compile. How should I implement this?

Thanks!

Hi,

Is this the correct way to do it?

std::stringstream stream;
    torch::save(policy_net.layer1,stream);
    torch::load(target_net.layer1,stream);
    std::stringstream().swap(stream);
    torch::save(policy_net.layer2,stream);
    torch::load(target_net.layer2,stream);
    std::stringstream().swap(stream);
    torch::save(policy_net.layer3,stream);
    torch::load(target_net.layer3,stream);

And should I add an operator like the following for struct DQN?

operator<<(std::ostream &stream, const DQN &module){
???
}

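For future reference: define the network as a module implementation class (`DQNImpl`) and wrap it with `TORCH_MODULE(DQN)`. This generates a `DQN` holder class that stores a shared pointer to a `DQNImpl`, so `torch::save` and `torch::load` can be used directly to copy the policy net into the target net.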

// Implementation class for the network. TORCH_MODULE(DQN) below generates
// the `DQN` holder type that the rest of the code constructs and passes around.
struct DQNImpl : torch::nn::Module
{
    // Submodule handles, registered as "layer1", "layer2" and "layer3".
    torch::nn::Linear layer1;
    torch::nn::Linear layer2;
    torch::nn::Linear layer3;

    // Builds the stack in -> hidden -> hidden -> out.
    DQNImpl(int in, int hidden, int out)
        : layer1(register_module("layer1", torch::nn::Linear(in, hidden))),
          layer2(register_module("layer2", torch::nn::Linear(hidden, hidden))),
          layer3(register_module("layer3", torch::nn::Linear(hidden, out)))
    {
    }

    // ReLU after the first two layers; the output layer is left linear.
    torch::Tensor forward(torch::Tensor x)
    {
        auto first_hidden = torch::relu(layer1->forward(x));
        auto second_hidden = torch::relu(layer2->forward(first_hidden));
        return layer3->forward(second_hidden);
    }
};
TORCH_MODULE(DQN);

int main()
{
    // Two networks built with the same layer sizes.
    DQN policy_net(5, 10, 3);
    DQN target_net(5, 10, 3);

    // Round-trip policy_net through an in-memory stream into target_net.
    std::stringstream buffer;
    torch::save(policy_net, buffer);
    torch::load(target_net, buffer);

    // Forward pass on an all-ones input of length 5.
    const auto input = torch::ones(5);
    const auto output = policy_net->forward(input);
}