Your problem is caused by dividing the result of `register_parameter`. The division creates a new tensor, and that new tensor (not the registered parameter) is what gets assigned to the member variable. As a result, the member can't be affected from outside the module by modifying the tensors returned by `.parameters()`.
The offending part is “/numHid”:
first = register_parameter("inputW", torch::rand({numIn, numHid}))/numHid;
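If you need the scaling, apply it before registering, e.g. `register_parameter("inputW", torch::rand({numIn, numHid}) / numHid)`. That way the member and the registered parameter are the same tensor.
Full example, using the approach from “How to copy parameters”: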
struct NetImpl : torch::nn::Module {
NetImpl(int numIn, int numOut, int numHid, const size_t hid_count = 1) {
assert(hid_count >= 1);
first = register_parameter("inputW", torch::rand({ numIn, numHid }));
middle = new torch::Tensor[hid_count - 1];
for (int i = 1; i != hid_count; i++)
middle[i] = register_parameter("hidW" + std::to_string(i), torch::rand({ numHid, numHid }));
last = register_parameter("outputW", torch::rand({ numHid, numOut }));
h_c = hid_count;
n_h = numHid;
}
torch::Tensor forward(torch::Tensor input)
{
torch::Tensor output_layer, h;
h = (torch::mm(input, first));
for (int i = 1; i != h_c; i++)
h = torch::sigmoid(torch::mm(h, middle[i]));
output_layer = (torch::mm(h, last));
return output_layer;
}
torch::Tensor first, last, *middle;
size_t h_c, n_h;
};
TORCH_MODULE(Net);
/// Demo: builds two identically shaped networks and copies every parameter of
/// nn0 into the same-named parameter of nn1. It then prints the difference of
/// their outputs for an all-ones input, which should be all zeros.
int main(int argc, char** argv)
{
    try
    {
        const int input_nodes = 20, output_nodes = 10, hidden_nodes = 2, hidden_count = 1;
        Net nn0(input_nodes, output_nodes, hidden_nodes, hidden_count);
        Net nn1(input_nodes, output_nodes, hidden_nodes, hidden_count);
        {
            // In-place copy_ into parameters (leaf tensors requiring grad) must
            // run with autograd disabled. The RAII guard restores the previous
            // grad mode on scope exit, including when an exception propagates.
            torch::NoGradGuard no_grad;
            auto src = nn0->named_parameters(true /*recurse*/);
            auto dst = nn1->named_parameters(true /*recurse*/);
            for (const auto& item : src)
            {
                // find() returns nullptr when nn1 has no parameter with this name.
                if (auto* target = dst.find(item.key()))
                {
                    target->copy_(item.value());
                }
            }
        }
        const torch::Tensor input = torch::ones({ 1, input_nodes });
        std::cout << "Diff: " << nn0->forward(input) - nn1->forward(input) << std::endl;
    }
    catch (const c10::Error& e)
    {
        std::cout << e.what() << std::endl;
    }
    catch (const std::runtime_error& e)
    {
        std::cout << e.what() << std::endl;
    }
    // Windows-only console pause; harmlessly fails elsewhere.
    system("PAUSE");
    return 0;
}