[libTorch] model initialization on multiple devices for parallel inference

I am using libTorch for inference. I have multiple GPU devices and use one thread per device. Inference works as expected, but model initialization only seems to run sequentially.
Once initialization is complete, the rest of the code runs concurrently as expected.
This was tolerable for smaller models, but with big models each initialization takes several minutes, so I am trying to make the model initialization as fast as possible.

Here is a minimal, reproducible example:

#include <chrono>
#include <string>
#include <vector>

#include <torch/torch.h>
#include <spdlog/spdlog.h>

using namespace torch;
namespace nn = torch::nn;
const torch::Device DEVICE = torch::Device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);

// a dummy model for demonstration
struct NetImpl : nn::Module {
    nn::Sequential layers;

    NetImpl(std::vector<int64_t> sizes, torch::Device device = DEVICE)
    : layers{ register_module("layers", torch::nn::Sequential()) }
    {
        for (size_t i = 0; i < sizes.size() - 1; i++) {
            layers->push_back(nn::Linear(sizes[i], sizes[i + 1]));
            layers->push_back(nn::Functional(torch::relu));
        }
        this->to(device); // move the freshly initialized (CPU) parameters to the target device
    }

    auto forward(Tensor x) -> Tensor {
        x = layers->forward(x);
        return x;
    }
};
TORCH_MODULE(Net);

struct Timer {

    std::string name;
    std::chrono::time_point<std::chrono::high_resolution_clock> start;

    Timer(std::string name="")
    : name {name}, start {std::chrono::high_resolution_clock::now()}
    {
        spdlog::info("Timer {} started", name);
    }

    double elapsed() {
        auto now = std::chrono::high_resolution_clock::now();
        // duration<double> keeps fractional seconds; duration_cast to
        // whole seconds would truncate them
        return std::chrono::duration<double>(now - start).count();
    }

    ~Timer() {
        spdlog::info("Timer {} ended: {:.3f}s", name, elapsed());
    }
};


int main() {
    spdlog::info("torch version {}", TORCH_VERSION);
    // deep network; FFN with a lot of layers to make it deep
    std::vector<int64_t> dims = { 
        1024, 4096, 8192, 16384, 8192, 4096, 1024, 512, 256, 512,
        1024, 4096, 8192, 16384, 8192, 4096, 1024, 512, 256, 512,
        1024, 4096, 8192, 16384, 8192, 4096, 1024, 512, 256, 512,
        1024, 4096, 8192, 16384, 8192, 4096, 1024, 512, 256, 512,
        1024, 4096, 8192, 16384, 8192, 4096, 1024, 512, 256, 512,
        };

    if (!torch::cuda::is_available()) {
        throw std::runtime_error("CUDA is not available");
    }
    std::vector<torch::Device> devices;
    for (size_t i = 0; i < torch::cuda::device_count(); i++) {
        devices.push_back(torch::Device(torch::kCUDA, static_cast<c10::DeviceIndex>(i)));
    }
    { // scope for the timer; the jthreads join on scope exit, before the timer reports
        int n_threads = devices.size();
        Timer timer(fmt::format("[{}-threaded initializer]", n_threads));
        std::vector<std::jthread> threads;
        for (int i = 0; i < n_threads; i++) {
            auto t = std::jthread([i, &dims, &devices] {
                auto device = devices[i];
                Timer timer(fmt::format("{}", device.str()));
                auto model = Net(dims, device);
            });
            threads.push_back(std::move(t));
        }
    }
    return 0;
}

With a single GPU, i.e. CUDA_VISIBLE_DEVICES=0:

[250108 04:12:39|t1753841][info] Timer [1-threaded initializer] started
[250108 04:12:39|t1753854][info] Timer cuda:0 started
[250108 04:12:53|t1753854][info] Timer cuda:0 ended: 14.000s
[250108 04:12:53|t1753841][info] Timer [1-threaded initializer] ended: 14.000s

Now, with CUDA_VISIBLE_DEVICES=0,1, the time almost doubles:

[250108 04:13:02|t1754149][info] Timer [2-threaded initializer] started
[250108 04:13:02|t1754163][info] Timer cuda:0 started
[250108 04:13:02|t1754164][info] Timer cuda:1 started
[250108 04:13:26|t1754164][info] Timer cuda:1 ended: 24.000s
[250108 04:13:27|t1754163][info] Timer cuda:0 ended: 24.000s
[250108 04:13:27|t1754149][info] Timer [2-threaded initializer] ended: 24.000s

And with CUDA_VISIBLE_DEVICES=0,1,2,3, the pattern continues:

[250108 04:14:04|t1754791][info] Timer [4-threaded initializer] started
[250108 04:14:04|t1754795][info] Timer cuda:0 started
[250108 04:14:04|t1754796][info] Timer cuda:1 started
[250108 04:14:04|t1754797][info] Timer cuda:2 started
[250108 04:14:04|t1754798][info] Timer cuda:3 started
[250108 04:14:52|t1754796][info] Timer cuda:1 ended: 47.000s
[250108 04:14:52|t1754795][info] Timer cuda:0 ended: 48.000s
[250108 04:14:58|t1754797][info] Timer cuda:2 ended: 54.000s
[250108 04:14:58|t1754798][info] Timer cuda:3 ended: 54.000s
[250108 04:14:58|t1754791][info] Timer [4-threaded initializer] ended: 54.000s

Finally, with all 8 devices:

[250108 04:15:50|t1755936][info] Timer [8-threaded initializer] started
[250108 04:15:50|t1755959][info] Timer cuda:0 started
[250108 04:15:50|t1755960][info] Timer cuda:1 started
[250108 04:15:50|t1755961][info] Timer cuda:2 started
[250108 04:15:50|t1755962][info] Timer cuda:3 started
[250108 04:15:50|t1755963][info] Timer cuda:4 started
[250108 04:15:50|t1755964][info] Timer cuda:5 started
[250108 04:15:50|t1755965][info] Timer cuda:6 started
[250108 04:15:50|t1755966][info] Timer cuda:7 started
[250108 04:17:23|t1755960][info] Timer cuda:1 ended: 92.000s
[250108 04:17:23|t1755965][info] Timer cuda:6 ended: 93.000s
[250108 04:17:24|t1755964][info] Timer cuda:5 ended: 93.000s
[250108 04:17:24|t1755959][info] Timer cuda:0 ended: 94.000s
[250108 04:17:24|t1755963][info] Timer cuda:4 ended: 94.000s
[250108 04:17:25|t1755966][info] Timer cuda:7 ended: 94.000s
[250108 04:17:25|t1755961][info] Timer cuda:2 ended: 95.000s
[250108 04:17:28|t1755962][info] Timer cuda:3 ended: 97.000s
[250108 04:17:28|t1755936][info] Timer [8-threaded initializer] ended: 97.000s

I can't see anything in NetImpl or nn::LinearImpl whose locking would enforce sequential execution.
It looks like some internal API (ATen/C10) is at play, and I have no idea how to work around it. How can I improve the parallelism in my case?
Alternatively, I just want to replicate/clone the model to all devices as fast as possible.
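
To narrow down where the time goes, here is a variant of the per-thread body I am experimenting with (a sketch reusing Timer, dims, and devices from the example above): it splits construction into the CPU-side random init and the device copy:

// same harness as above, but with split timing: construct on the CPU first
// (random weight init), then move the parameters to this thread's GPU
auto t = std::jthread([i, &dims, &devices] {
    auto device = devices[i];
    Net model = nullptr;  // empty module holder
    {
        Timer t_init(fmt::format("{} cpu-init", device.str()));
        model = Net(dims, torch::kCPU);  // keep the random init on the CPU
    }
    {
        Timer t_copy(fmt::format("{} copy", device.str()));
        model->to(device);  // pure parameter copy, no RNG involved
    }
});

If the cpu-init timers are the ones that grow with the thread count, the contention would be in the CPU-side weight initialization rather than in the host-to-device copies.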

I also tried Cloneable, but even this did NOT reduce the time for model replication/cloning:

struct NetImpl: public nn::Cloneable<NetImpl> {
  std::vector<int64_t> sizes;
  nn::Sequential layers;

  NetImpl(std::vector<int64_t> sizes)
  : sizes {sizes}
  {
      NetImpl::reset();
  }
  
  void reset() override { // also called by Cloneable::clone() to rebuild the submodules
     layers = register_module("layers", torch::nn::Sequential());
     for (size_t i = 0; i < sizes.size() - 1; i++) {
          layers->push_back(nn::Linear(sizes[i], sizes[i + 1]));
          layers->push_back(nn::Functional(torch::relu));
      }
  }

  Tensor forward(Tensor x) {
      x = layers->forward(x);
      return x;
  }
};
TORCH_MODULE(Net);

And…

// create the model once, on the CPU
auto cpu_model = Net(dims);
// ... inside each thread; one per device
// (clone() returns std::shared_ptr<Module>, so cast back to NetImpl)
auto model = std::dynamic_pointer_cast<NetImpl>(cpu_model->clone(device));
I expected this to take roughly 2x the single-model time: 1x to build the model on the CPU, plus 1x for all devices to clone it in parallel. Instead, for 8 GPUs it takes about (1 + 8) = 9x, i.e. the per-device clones still run one after another.
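
For completeness, here is roughly how I drive and time the clone path (a sketch reusing the Timer and devices from the first example; the dynamic_pointer_cast is there because Module::clone() returns std::shared_ptr<Module>):

auto cpu_model = Net(dims);  // built once on the CPU: ~1x
{
    Timer timer(fmt::format("[{}-threaded cloner]", devices.size()));
    std::vector<std::jthread> threads;
    for (size_t i = 0; i < devices.size(); i++) {
        threads.push_back(std::jthread([i, &cpu_model, &devices] {
            Timer t(devices[i].str());
            // expected: ~1x wall time across all threads; observed: ~8x
            auto replica = std::dynamic_pointer_cast<NetImpl>(
                cpu_model->clone(devices[i]));
        }));
    }
} // jthreads join here, then the outer timer reports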