Is inference thread-safe?

I know that the code below is thread-safe (many modules, many threads):

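#include <torch/script.h>  // torch::jit::load, torch::jit::script::Module
#include <torch/torch.h>   // torch::rand, torch::NoGradGuard

#include <memory>
#include <string>
#include <thread>
#include <vector>

// Runs inference on the module owned by this thread (modules[idx]).
void foo(const std::vector<std::shared_ptr<torch::jit::script::Module>>& modules, int idx)
{
	torch::NoGradGuard no_grad;  // inference only: disable autograd tracking
	torch::Tensor X = torch::rand({ 1, 10, 1, 300 });
	auto out = modules[idx]->forward({ X });
}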

void first_inference()
{
	using namespace torch::jit;
	const std::string model_path{ "model.pth" };

	std::vector<std::shared_ptr<torch::jit::script::Module>> modules;
	modules.reserve(10);
	for (int i = 0; i < 10; i++) {
		modules.push_back(std::make_shared<script::Module>(load(model_path)));
	}

	std::vector<std::thread> vec;
	for (int i = 0; i < 10; i++) {
		std::thread t(foo, std::ref(modules), i);
		vec.push_back(std::move(t));
	}

	for (auto& t : vec)
		t.join();
}

What about the following code (one module, many threads)? Is it thread-safe?
I checked that forward doesn't take a global lock. The program doesn't crash and produces correct results.

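// Every thread calls forward on the same shared module instance.
void bar(const std::shared_ptr<torch::jit::script::Module>& module)
{
	torch::NoGradGuard no_grad;  // inference only: disable autograd tracking
	torch::Tensor X = torch::rand({ 1, 10, 1, 300 });
	auto out = module->forward({ X });
}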

void second_inference()
{
	using namespace torch::jit;
	const std::string model_path{ "model.pth" };

	// A single module instance, shared by all threads.
	std::shared_ptr<script::Module> module{ std::make_shared<script::Module>(load(model_path)) };

	std::vector<std::thread> vec;
	for (int i = 0; i < 10; i++) {
		std::thread t(bar, std::ref(module));
		vec.push_back(std::move(t));
	}

	for (auto& t : vec)
		t.join();
}

Hi,

In general, all objects in PyTorch are thread-safe to read from, but not to write into.
If your Module doesn't write into any shared structure, then yes, it should work just fine.
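The read/write rule above can be illustrated without libtorch. This is a minimal sketch under stated assumptions: ToyModule is a hypothetical stand-in (not a PyTorch class) whose weights are fixed at construction and whose forward is const and only reads them. Several threads call forward on one shared instance, and each thread writes only to its own output slot, so no locking is needed.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <thread>
#include <utility>
#include <vector>

// Hypothetical stand-in for a loaded module: parameters are set once at
// construction, and forward() only reads them.
class ToyModule {
public:
    explicit ToyModule(std::vector<float> weights) : weights_(std::move(weights)) {}

    // const: forward() never writes to shared state, so concurrent calls are safe.
    float forward(const std::vector<float>& input) const {
        assert(input.size() == weights_.size());
        float sum = 0.0f;
        for (std::size_t i = 0; i < input.size(); ++i)
            sum += weights_[i] * input[i];
        return sum;
    }

private:
    const std::vector<float> weights_;
};

// Calls forward() on one shared module from num_threads threads. Each thread
// writes only to its own element of `results`, so there is no data race.
std::vector<float> run_shared(const std::shared_ptr<const ToyModule>& module,
                              const std::vector<float>& input, int num_threads) {
    std::vector<float> results(num_threads);
    std::vector<std::thread> threads;
    for (int t = 0; t < num_threads; ++t)
        threads.emplace_back([&module, &input, &results, t] {
            results[t] = module->forward(input);
        });
    for (auto& th : threads)
        th.join();
    return results;
}
```

For example, run_shared(std::make_shared<ToyModule>(std::vector<float>{1, 2, 3}), {1, 1, 1}, 8) should return eight copies of 6. Sharing is safe here only because nothing is written during forward; as soon as forward mutates a member, the threads need synchronization or their own copies.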

I think my TorchScript model doesn't change the internal state of the module. It is only used for the forward pass and doesn't compute gradients.

Does this mean that inference will work correctly for one module shared across many threads if I load, for example, a TorchScript ResNet from an official PyTorch URL?

Does this also hold on the GPU?

Yes, having multiple modules or using the GPU doesn't change anything there.

For the JIT I'm less sure, but I would expect inference there to follow the same rules as regular inference.


Yes, as long as you are not mutating elements of the module state in your forward pass, inference is thread-safe in TorchScript.
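To show what "mutating module state in the forward pass" means for threading, here is a minimal sketch with a hypothetical StatefulToyModule (not a PyTorch class) whose forward updates a running mean of its inputs, loosely analogous to a BatchNorm layer left in training mode updating its running statistics on every call. Because forward writes shared members, concurrent calls race unless they are serialized; this sketch assumes a std::mutex is an acceptable way to do that.

```cpp
#include <cassert>
#include <cstddef>
#include <mutex>
#include <thread>
#include <vector>

// Hypothetical module whose forward() mutates internal state (a running mean).
// Without the mutex, concurrent calls would race on count_ and mean_.
class StatefulToyModule {
public:
    double forward(double x) {
        std::lock_guard<std::mutex> lock(mutex_);  // serialize writes to shared state
        ++count_;
        mean_ += (x - mean_) / static_cast<double>(count_);
        return x - mean_;
    }

    std::size_t count() const {
        std::lock_guard<std::mutex> lock(mutex_);
        return count_;
    }

    double mean() const {
        std::lock_guard<std::mutex> lock(mutex_);
        return mean_;
    }

private:
    mutable std::mutex mutex_;
    std::size_t count_ = 0;
    double mean_ = 0.0;
};

// Calls forward(1.0) calls_per_thread times from each of num_threads threads.
void hammer(StatefulToyModule& module, int num_threads, int calls_per_thread) {
    assert(num_threads > 0 && calls_per_thread >= 0);
    std::vector<std::thread> threads;
    for (int t = 0; t < num_threads; ++t)
        threads.emplace_back([&module, calls_per_thread] {
            for (int i = 0; i < calls_per_thread; ++i)
                module.forward(1.0);
        });
    for (auto& th : threads)
        th.join();
}
```

With the lock, hammer(m, 8, 1000) should leave m.count() at exactly 8000. The trade-off is that the lock serializes every call, so the threads no longer run forward in parallel; giving each thread its own module copy, as in first_inference, or keeping forward free of writes avoids that cost.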
