With the C++ API, how can I share a model between multiple threads
and protect the weight update inside a critical section?
The snippet below shows the situation, but I need help to make it work.
In the PyTorch tests, I can see idioms like
auto module = std::make_shared<TestModule>();
but I don't see how to adapt them here.
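My best guess at the adaptation is something like this (I am only guessing, and Net is the module defined in the full snippet below):
auto net = std::make_shared<Net>();   // one shared instance of the module
torch::optim::SGD opt(net->parameters(),
                      torch::optim::SGDOptions(0.1).momentum(0.9));
#pragma omp parallel
{
    // every thread reads the same parameters through the shared pointer
    auto y = net->forward(torch::rand({10}));
}
Here is the full snippet of what I have now: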
#include <torch/torch.h>
#include <omp.h>
using namespace torch;
using namespace torch::nn;
using namespace torch::optim;
struct Net : Module
{
    Linear fc1 = nullptr;
    Linear fc2 = nullptr;

    Net()
    {
        fc1 = register_module("fc1", Linear(10, 10));
        fc2 = register_module("fc2", Linear(10, 2));
    }

    Tensor forward(Tensor x)
    {
        x = fc1->forward(x);
        x = relu(x);
        x = fc2->forward(x);
        x = sigmoid(x);
        return x;
    }
};
int main()
{
    Net nn0;
    SGD opt(nn0.parameters(), SGDOptions(0.1).momentum(0.9));

    #pragma omp parallel
    {
        Net& nn = nn0; // every thread works on the same shared network
        opt.zero_grad();
        auto x = rand({10});
        auto y = nn.forward(x);
        auto target = rand({2}); // fc2 outputs 2 values, so the target must have size 2
        auto loss = mse_loss(y, target);
        loss.backward();
        #pragma omp critical
        {
            opt.step(); // like this?
        }
    }
}
In my real code, everything between opt.zero_grad() and opt.step() is executed several times by each thread.
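The only fully safe version I can come up with is to put everything from opt.zero_grad() to opt.step() inside the critical section, roughly as below (the loop bound 100 is just a placeholder for my real per-thread loop), but that serializes the threads completely, so I doubt it is the right way:
#pragma omp parallel
{
    for (int i = 0; i < 100; ++i) // placeholder for the real per-thread update loop
    {
        auto x = rand({10});
        auto target = rand({2});
        #pragma omp critical
        {
            // everything that touches the shared parameters and gradients is serialized
            opt.zero_grad();
            auto loss = mse_loss(nn0.forward(x), target);
            loss.backward();
            opt.step();
        }
    }
}
Is there a way to keep the forward/backward parts parallel and only protect the actual weight update?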
Thanks.