Torch.jit._fork is not working and all operators run sequentially

I followed this link cpu_threading_torchscript_inference to try to enable inter-op multithreading. In my test script, I called torch.set_num_interop_threads(8) to use all my CPU cores and ran 8 Conv1d operators with torch.jit._fork. The results show that torch.jit._fork is not working and all operators run sequentially.

During the test, CPU usage stays at roughly 100% (a single core), which suggests that multithreading is not actually enabled.
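For reference, the effective thread settings can be checked with torch.__config__.parallel_info(); a quick sketch, assuming a recent PyTorch build:

    import torch

    torch.set_num_threads(1)          # intra-op threads per operator
    torch.set_num_interop_threads(8)  # inter-op pool used by torch.jit._fork

    # Prints the ATen/Parallel settings: intra-op and inter-op thread
    # counts, plus the parallel backend (e.g. OpenMP) in use.
    print(torch.__config__.parallel_info())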

I also get the warning “Access to a protected member _fork of a class” for torch.jit._fork; does this matter?

Any answers are appreciated. Thank you for your help!

My core script

    def forward(self, x, threads, seq):
        interval = seq // threads
        conv_res = []
        conv_threads = []
        start = time.time() * 1000
        for i in range(threads):
            start_inner = time.time() * 1000
            # Fork one convolution per slice of the sequence dimension.
            y = torch.jit._fork(self.compute, x[:, :, i * interval:(i + 1) * interval])
            print("fork %d cost %d ms" % (i, time.time() * 1000 - start_inner))
            conv_threads.append(y)
        print("fork totally cost %d ms" % (time.time() * 1000 - start))
        start = time.time() * 1000
        for i in range(threads):
            # Wait for each forked future and collect its result.
            conv_res.append(torch.jit._wait(conv_threads[i]))
        print("wait cost %d ms" % (time.time() * 1000 - start))
        return conv_res

Test result

    fork 0 cost 5 ms
    fork 1 cost 5 ms
    fork 2 cost 4 ms
    fork 3 cost 5 ms
    fork 4 cost 4 ms
    fork 5 cost 5 ms
    fork 6 cost 5 ms
    fork 7 cost 5 ms
    fork totally cost 41 ms
    wait cost 0 ms

Full script

    import time
    import torch
    import torch.nn as nn


    class MyConvParallel(nn.Module):
        def __init__(self, *args, **kwargs):
            super().__init__()
            self.cell = nn.Conv1d(*args, **kwargs)
            self.cell.weight.data.normal_(0.0, 0.02)

        def compute(self, x):
            return self.cell(x)

        def forward(self, x, threads, seq):
            interval = seq // threads
            conv_res = []
            conv_threads = []
            start = time.time() * 1000
            for i in range(threads):
                start_inner = time.time() * 1000
                # Fork one convolution per slice of the sequence dimension.
                y = torch.jit._fork(self.compute, x[:, :, i * interval:(i + 1) * interval])
                print("fork %d cost %d ms" % (i, time.time() * 1000 - start_inner))
                conv_threads.append(y)
            print("fork totally cost %d ms" % (time.time() * 1000 - start))
            start = time.time() * 1000
            for i in range(threads):
                # Wait for each forked future and collect its result.
                conv_res.append(torch.jit._wait(conv_threads[i]))
            print("wait cost %d ms" % (time.time() * 1000 - start))
            return conv_res


    def main():
        #print(*torch.__config__.show().split("\n"), sep="\n");exit(0)
        intra_threads = 1
        inter_threads = 8
        dim = 64
        kernels = 3
        seq = 80000
        torch.set_num_threads(intra_threads)          # intra-op threads per operator
        torch.set_num_interop_threads(inter_threads)  # inter-op pool for _fork
        MyCell = MyConvParallel(dim, dim, kernel_size=kernels, stride=1)
        MyCell.eval()
        inputs = []
        iters = 1000
        for i in range(iters):
            inputs.append(torch.rand(1, dim, seq))

        start = time.time() * 1000
        for i in range(iters):
            print(i)
            y = MyCell(inputs[i], inter_threads, seq)
            #print(y)
        end = time.time() * 1000
        print('cost %d ms per iter\n' % ((end - start) / iters))


    if __name__ == "__main__":
        main()

That only works in TorchScript; try MyCell = torch.jit.script(MyCell).
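In eager mode, torch.jit._fork just invokes the callable synchronously, which matches the timings above: each fork call takes ~5 ms while the wait is free, so the convolutions ran inline during _fork. Below is a minimal sketch of a scripted version, simplified from the full script: the time-based instrumentation is dropped because TorchScript does not support the time module, the convolution submodule is forked directly, and the empty future list needs an explicit type annotation to compile.

    from typing import List

    import torch
    import torch.nn as nn


    class MyConvParallel(nn.Module):
        def __init__(self, dim: int, kernel_size: int):
            super().__init__()
            self.cell = nn.Conv1d(dim, dim, kernel_size=kernel_size, stride=1)

        def forward(self, x, threads: int, seq: int):
            interval = seq // threads
            # TorchScript needs the element type of an empty list spelled out.
            futures: List[torch.jit.Future[torch.Tensor]] = []
            for i in range(threads):
                # Under torch.jit.script this schedules the convolution on the
                # inter-op thread pool instead of running it inline.
                futures.append(
                    torch.jit._fork(self.cell, x[:, :, i * interval:(i + 1) * interval]))
            return [torch.jit._wait(f) for f in futures]


    torch.set_num_threads(1)          # intra-op threads per operator
    torch.set_num_interop_threads(8)  # pool used by the forked tasks
    scripted = torch.jit.script(MyConvParallel(64, 3))
    out = scripted(torch.rand(1, 64, 80000), 8, 80000)

As an aside, newer PyTorch releases also expose these under the public names torch.jit.fork and torch.jit.wait, which avoids the protected-member warning mentioned above.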

Thanks @googlebot! Problem resolved.