C++, Share a model between multiple threads and protect weight updates

With the C++ API, how do I share a model between multiple threads
and protect the weight update in a critical section?

The snippet below shows the situation, but I need help to make it work.
In the PyTorch tests, I can see idioms like

auto module = std::make_shared<TestModule>();

But I don't see how to adapt them.
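
For context, here is a minimal sketch of how I read that shared_ptr idiom (the body of TestModule is just a placeholder I made up, not the one from the tests):

#include <torch/torch.h>
#include <omp.h>
#include <memory>

// Placeholder module, only to illustrate the shared_ptr idiom.
struct TestModule : torch::nn::Module
{
  torch::nn::Linear fc = nullptr;

  TestModule()
  {
    fc = register_module("fc", torch::nn::Linear(10, 10));
  }

  torch::Tensor forward(torch::Tensor x)
  {
    return fc->forward(x);
  }
};

int main()
{
  // one instance, shared ownership
  auto module = std::make_shared<TestModule>();

  #pragma omp parallel
  {
    // copying the shared_ptr copies the handle, not the weights:
    // every thread sees the same parameters
    auto local = module;
    auto y = local->forward(torch::rand({10}));
  }
}

But in my case the model is also trained from the OpenMP threads, which is what the snippet below tries to do.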

#include <torch/torch.h>
#include <omp.h>

using namespace torch;
using namespace torch::nn;
using namespace torch::optim;

struct Net : Module 
{
  Linear fc1 = nullptr;
  Linear fc2 = nullptr;

  Net()
  {
    fc1 = register_module("fc1", Linear(10, 10));
    fc2 = register_module("fc2", Linear(10, 2));
  }

  Tensor forward(Tensor x) 
  {
    x = fc1->forward(x);
    x = relu(x);
    x = fc2->forward(x);
    x = sigmoid(x);
    return x;
  }
};

int main() 
{
  Net nn0;
  SGD opt(nn0.parameters(), SGDOptions(0.1).momentum(0.9));

  #pragma omp parallel
  {
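    // nn is a reference to the shared model: every thread updates the same parameters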
    Net& nn = nn0;

    opt.zero_grad();

    auto x = rand({10});
    auto y = nn.forward(x);
    auto target = rand({2}); // fc2 outputs 2 values, so the target must match
    auto loss = mse_loss(y, target);

    loss.backward();

    #pragma omp critical
    {
      opt.step(); // like this?
    }

  }
}

The code between opt.zero_grad() and opt.step() is executed several times by each thread in my real code.
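
To be concrete, each thread does roughly this (the iteration count is a placeholder, and whether protecting only opt.step() is enough is exactly what I am unsure about):

#include <torch/torch.h>
#include <omp.h>

int main()
{
  torch::nn::Linear net(10, 2);
  torch::optim::SGD opt(net->parameters(),
                        torch::optim::SGDOptions(0.1).momentum(0.9));

  #pragma omp parallel
  {
    // each thread repeats the zero_grad ... step sequence
    for (int i = 0; i < 100; ++i) // placeholder iteration count
    {
      opt.zero_grad();

      auto x = torch::rand({10});
      auto loss = torch::mse_loss(net->forward(x), torch::rand({2}));
      loss.backward();

      #pragma omp critical
      {
        opt.step(); // only the weight update is protected here
      }
    }
  }
}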

Thanks.
