Internal mutex in C++ quantized backend

Hello! I have a deep learning model that I want to port to C++ and run with multi-threaded inference. My use case requires every thread to have its own model replica, and each thread must run the model on a single core.
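
(The per-core pinning itself is not shown in the snippets below. Just for context, here is a minimal Linux-specific sketch of how each worker thread can be pinned to one core; the helper name is mine and not part of the repro:)

#include <pthread.h>
#include <sched.h>

// Pin the calling thread to a single CPU core (Linux-specific; relies on
// _GNU_SOURCE, which g++ defines by default). Returns true on success.
bool pin_current_thread_to_core(int core_id) {
    cpu_set_t cpuset;
    CPU_ZERO(&cpuset);
    CPU_SET(core_id, &cpuset);
    return pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset) == 0;
}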

Here is the Python script:

import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
import tqdm
import argparse
import torch
import torch.nn.quantized
import torch.quantization


def make_fused_linear(in_features: int, out_features: int):
    return torch.quantization.fuse_modules(
        torch.nn.Sequential(
            torch.nn.Linear(in_features=in_features, out_features=out_features),
            torch.nn.ReLU(inplace=True)
        ),
        modules_to_fuse=['0', '1']
    )


class FeedforwardModel(torch.nn.Module):
    def __init__(self, features):
        super(FeedforwardModel, self).__init__()
        self._net = torch.nn.Sequential(
            make_fused_linear(features, 90),
            make_fused_linear(90, 90),
            make_fused_linear(90, 90),
            make_fused_linear(90, 90),
            make_fused_linear(90, 90),
            make_fused_linear(90, 90),
        )
        self._final = torch.nn.Linear(90, 50)
        self._quant = torch.quantization.QuantStub()
        self._dequant = torch.quantization.DeQuantStub()

    def forward(self, x: torch.Tensor):
        x = self._quant(x)
        x = self._final(self._net(x))
        x = self._dequant(x)
        return x


def timeit_model(model, *inputs):
    # effectively an infinite loop: keep the core busy so throughput can be observed
    for _ in tqdm.trange(10000000000000):
        with torch.no_grad():
            model(*inputs)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("-q", action="store_true", help="Use quantized model")
    parser.add_argument("-b", default=1, help="Batch size", type=int)

    torch.set_num_interop_threads(1)
    torch.set_num_threads(1)

    args = parser.parse_args()
    use_quantized = args.q
    batch_size = args.b

    in_features = 40 * 64  # new user model with 40 queries
    inputs = torch.rand(batch_size, in_features)

    with torch.no_grad():
        if not use_quantized:
            model = FeedforwardModel(in_features)
            model.eval()
            traced_script_module = torch.jit.trace(model, inputs)
            traced_script_module.save("model.torch")

            timeit_model(traced_script_module, inputs)
        else:
            model = FeedforwardModel(in_features)
            model.eval()
            model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
            torch.quantization.prepare(model, inplace=True)

            # Calibration pass: run a representative input through the observers
            model(inputs)

            torch.quantization.convert(model, inplace=True)

            traced_script_module = torch.jit.trace(model, inputs)
            traced_script_module.save("quantized_model.torch")

            timeit_model(traced_script_module, inputs)

And the corresponding C++ code:

#include <iostream>
#include <future>
#include <torch/all.h>
#include <torch/script.h> // One-stop header.


int main(int argc, const char* argv[]) {
    // WARNING! This does not work for the quantized model! For the quantized model, only setting
    // MKL_NUM_THREADS=1 and OMP_NUM_THREADS=1 works! We are investigating this issue with the PyTorch folks.
    torch::set_num_threads(1);
    at::set_num_interop_threads(1);

    if (argc < 3) {
        std::cerr << "Usage: " << argv[0] << " <model_path> <num_threads>" << std::endl;
        return 1;
    }

    const std::string model_path = argv[1];
    const auto num_threads = std::stoi(argv[2]);

    torch::Tensor inputs = torch::zeros({1, 40 * 64}, torch::kFloat32);

    std::vector<std::future<void>> futures;
    for (auto i = 0; i < num_threads; i++) {
        futures.emplace_back(std::async(std::launch::async, [model_path, inputs]() {
            torch::NoGradGuard torch_guard;
            auto model = torch::jit::load(model_path);
            model.eval();
            auto thread_inputs = inputs.clone();

            while (true) {
                model.forward({thread_inputs});
            }
        }));
    }

    for (auto &f : futures) { f.get(); }

    return 0;
}

The first issue is the one described in the C++ comment: setting the number of threads programmatically does not work with the quantized backend.
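
What does work is exporting MKL_NUM_THREADS=1 and OMP_NUM_THREADS=1 before launching the binary. Presumably (I have not verified this) the same effect could be achieved by setting the variables at the very top of main(), before the first libtorch call, since the OpenMP/MKL runtimes should not have read them yet at that point:

#include <cstdlib>

int main(int argc, const char* argv[]) {
    // Must run before the first libtorch/fbgemm call; otherwise the
    // OpenMP/MKL runtimes may already have picked up their defaults.
    setenv("OMP_NUM_THREADS", "1", /*overwrite=*/1);
    setenv("MKL_NUM_THREADS", "1", /*overwrite=*/1);
    // ... rest of the program as above ...
    return 0;
}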

The second, more severe issue: when I launch the code with 40 threads on a 40-core machine, the floating-point model parallelizes perfectly (as it should), but the quantized model gets stuck on some mutex. This is easily seen both in htop (CPU cores spend time in kernel syscalls) and in strace: strace shows that the quantized model calls futex very frequently, while the floating-point model makes no syscalls after all threads have started.

Can you help me get rid of this lock? Maybe I'm doing something wrong?

I think the quantized backend uses fbgemm, which relies on MKL_NUM_THREADS=1 and OMP_NUM_THREADS=1 to control threading. cc @raghuramank100 @jianyuhuang for the second issue.

Are you using a recent version of PyTorch (and is it compiled with C++14)? We recently made an improvement in the quantized backend library (fbgemm) to use the read-write lock from C++14, and those are not supported on macOS due to known issues. Are you running it on macOS?
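
To give an idea of the kind of change (a generic sketch only — the class and names below are illustrative, not fbgemm's actual code): lookups in a read-mostly cache take a shared lock, so concurrent readers no longer serialize, and only insertions take an exclusive lock.

#include <mutex>          // std::unique_lock
#include <shared_mutex>   // std::shared_timed_mutex, std::shared_lock (C++14)
#include <string>
#include <unordered_map>

// Generic read-write-lock sketch: readers no longer serialize on a plain
// std::mutex; only writers take an exclusive lock.
class PackedWeightCache {
public:
    bool find(int key, std::string* out) const {
        std::shared_lock<std::shared_timed_mutex> guard(mutex_);   // shared (read) lock
        const auto it = cache_.find(key);
        if (it == cache_.end()) return false;
        *out = it->second;
        return true;
    }

    void insert(int key, std::string value) {
        std::unique_lock<std::shared_timed_mutex> guard(mutex_);   // exclusive (write) lock
        cache_.emplace(key, std::move(value));
    }

private:
    mutable std::shared_timed_mutex mutex_;
    std::unordered_map<int, std::string> cache_;
};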

Thank you!

I am already using these environment variables.

Thank you!

I am running the code on a Linux server with a 46-core processor. Unfortunately, getting the latest libtorch did not help.

I will try to profile the code and figure out which part causes the issue.
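
As a first step I will probably just measure per-thread throughput around model.forward() with std::chrono before reaching for perf/strace — a rough sketch, assuming the same worker lambda as above:

#include <chrono>
#include <iostream>

// Rough per-thread throughput measurement (iterations per second) around an
// arbitrary callable; not a real profiler, just a first sanity check.
template <typename Fn>
double iterations_per_second(Fn&& step, int iterations = 10000) {
    const auto start = std::chrono::steady_clock::now();
    for (int i = 0; i < iterations; ++i) {
        step();
    }
    const std::chrono::duration<double> elapsed =
        std::chrono::steady_clock::now() - start;
    return iterations / elapsed.count();
}

// Inside the worker lambda, replacing the while (true) loop:
//   std::cout << iterations_per_second([&] { model.forward({thread_inputs}); })
//             << " it/s" << std::endl;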

Thanks. Profiling would be of great help.