JIT benchmark and the future direction of this technique

I want to find some experiments that clearly show the benefit of torch.jit.script.

It clearly improves inference speed, but the gain seems small.

I set up a simple experiment like this:

import torch
import torch.nn as nn
import torch.jit as jit
from torch import Tensor
from typing import Tuple

class QueryKeyValueUnoptimized(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        self.query = nn.Linear(512, 256)
        self.key = nn.Linear(512, 256)
        self.value = nn.Linear(512, 256)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        return self.query(x), self.key(x), self.value(x)


class QueryKeyValueOptimized(nn.Module):
    def __init__(self):
        super().__init__()
        # one 512 -> 768 layer replaces the three 512 -> 256 layers
        self.linear = nn.Linear(512, 768)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        x = self.linear(x)
        # slice the output back into query, key and value (assumes 2-D input)
        return x[:, :256], x[:, 256:512], x[:, 512:]

net_unoptimized_nojit_cpu = QueryKeyValueUnoptimized()
net_unoptimized_jit_cpu = jit.script(net_unoptimized_nojit_cpu)
net_optimized_nojit_cpu = QueryKeyValueOptimized()
net_optimized_jit_cpu = jit.script(net_optimized_nojit_cpu)

device = torch.device("cpu")
x = torch.zeros(128, 512, device=device)

%timeit net_unoptimized_nojit_cpu(x)
%timeit net_unoptimized_jit_cpu(x)
%timeit net_optimized_nojit_cpu(x)
%timeit net_optimized_jit_cpu(x)

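# nn.Module.to() moves parameters in place and returns the same module,
# so the *_cpu modules above are now on the GPU as well.
device = torch.device("cuda")
net_unoptimized_nojit_cuda = net_unoptimized_nojit_cpu.to(device)
net_unoptimized_jit_cuda = net_unoptimized_jit_cpu.to(device)
net_optimized_nojit_cuda = net_optimized_nojit_cpu.to(device)
net_optimized_jit_cuda = net_optimized_jit_cpu.to(device)
x = x.to(device)

# CUDA kernels run asynchronously; without torch.cuda.synchronize(),
# %timeit may mostly reflect kernel-launch time rather than GPU execution time.
%timeit net_unoptimized_nojit_cuda(x)
%timeit net_unoptimized_jit_cuda(x)
%timeit net_optimized_nojit_cuda(x)
%timeit net_optimized_jit_cuda(x)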

net_unoptimized_cuda_jit = jit.script(QueryKeyValueUnoptimized().to(device))
net_optimized_cuda_jit = jit.script(QueryKeyValueOptimized().to(device))

%timeit net_unoptimized_cuda_jit(x)
%timeit net_optimized_cuda_jit(x)
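The numbers reported next were measured with the `%timeit` calls above. For anyone re-running the comparison, here is a minimal sketch using `torch.utils.benchmark.Timer`, which performs warm-up runs and synchronizes asynchronous CUDA work when needed. It was not used to produce the table. The helper name `benchmark_all` and the `min_run_time` default are illustrative assumptions; note that `Timer` uses a single thread unless `num_threads` is set, so the sketch passes the current intra-op thread count.

```python
from typing import Dict, Tuple

import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark
from torch import Tensor


class QueryKeyValueUnoptimized(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.query = nn.Linear(512, 256)
        self.key = nn.Linear(512, 256)
        self.value = nn.Linear(512, 256)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        return self.query(x), self.key(x), self.value(x)


class QueryKeyValueOptimized(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(512, 768)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        x = self.linear(x)
        return x[:, :256], x[:, 256:512], x[:, 512:]


def benchmark_all(device: torch.device, min_run_time: float = 0.2) -> Dict[str, float]:
    """Hypothetical helper: median time per forward call, in microseconds."""
    x = torch.zeros(128, 512, device=device)
    variants = {
        "unoptimized": QueryKeyValueUnoptimized(),
        "unoptimized+jit": torch.jit.script(QueryKeyValueUnoptimized()),
        "optimized": QueryKeyValueOptimized(),
        "optimized+jit": torch.jit.script(QueryKeyValueOptimized()),
    }
    results: Dict[str, float] = {}
    for name, net in variants.items():
        net = net.to(device).eval()
        timer = benchmark.Timer(
            stmt="net(x)",
            globals={"net": net, "x": x},
            # Timer defaults to 1 thread; use the same count as eager PyTorch.
            num_threads=torch.get_num_threads(),
        )
        # Repeats the statement for at least min_run_time seconds.
        measurement = timer.blocked_autorange(min_run_time=min_run_time)
        results[name] = measurement.median * 1e6  # seconds -> microseconds
    return results


if __name__ == "__main__":
    devices = [torch.device("cpu")]
    if torch.cuda.is_available():
        devices.append(torch.device("cuda"))
    for device in devices:
        for name, us in benchmark_all(device).items():
            print(f"{device.type:4s} {name:16s} {us:8.1f} us")
```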

The results are as follows:

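Manual optimization | JIT scripting | Device | Inference time (µs)
------------------- | ------------- | ------ | -------------------
No                  | No            | CPU    | 623
No                  | Yes           | CPU    | 588
Yes                 | No            | CPU    | 290
Yes                 | Yes           | CPU    | 277
No                  | No            | GPU    | 191
No                  | Yes           | GPU    | 170
Yes                 | No            | GPU    | 94.3
Yes                 | Yes           | GPU    | 97.5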
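For reference, both modules perform the same amount of arithmetic on the 128 × 512 input, so the manual-fusion speedup does not come from fewer floating-point operations. Counting the multiply-adds in the matrix products (the bias additions are also identical, 128 × 768 in both cases):

```latex
\underbrace{3 \times (128 \times 512 \times 256)}_{\text{three } 512 \to 256 \text{ layers}}
= \underbrace{128 \times 512 \times 768}_{\text{one } 512 \to 768 \text{ layer}}
= 50{,}331{,}648 \ \text{multiply-adds}
```

The roughly 2× gap between the fused and unfused rows, on both CPU and GPU, therefore reflects how the work is executed (one larger matrix multiply instead of three smaller ones, with the input read once instead of three times), although this benchmark does not isolate which factor dominates.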

I have three questions.

  • Are there any official or well-known benchmark results for JIT scripting?
  • The results suggest that we still have to optimize the code manually even after scripting the Module with JIT. Is that right?
  • Does PyTorch not support the kind of fusion I did manually in this experiment (merging the three Linear layers into a single one)?
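On the last question, the manual fusion is a pure rewrite of the weights: stacking the three 256 × 512 weight matrices (and their biases) gives one 768 × 512 `nn.Linear` with the same outputs. The sketch below performs that conversion on an already-constructed `QueryKeyValueUnoptimized`. `FusedQueryKeyValue` and `fuse_qkv` are hypothetical names, and using `split(256, dim=-1)` instead of `x[:, :256]`-style slicing is an assumption that also lets the module accept inputs with extra leading dimensions. It only shows that the two forms are equivalent, not whether TorchScript applies this rewrite automatically.

```python
from typing import Tuple

import torch
import torch.nn as nn
from torch import Tensor


class QueryKeyValueUnoptimized(nn.Module):
    # Same as in the experiment: three separate 512 -> 256 projections.
    def __init__(self) -> None:
        super().__init__()
        self.query = nn.Linear(512, 256)
        self.key = nn.Linear(512, 256)
        self.value = nn.Linear(512, 256)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        return self.query(x), self.key(x), self.value(x)


class FusedQueryKeyValue(nn.Module):
    # Hypothetical fused variant: one 512 -> 768 projection split into three.
    def __init__(self, linear: nn.Linear) -> None:
        super().__init__()
        self.linear = linear

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        # Splitting the last dim works for (N, 512) and (B, T, 512) inputs.
        qkv = self.linear(x).split(256, dim=-1)
        return qkv[0], qkv[1], qkv[2]


def fuse_qkv(m: QueryKeyValueUnoptimized) -> FusedQueryKeyValue:
    """Hypothetical helper: copy the three layers' parameters into one Linear."""
    fused = nn.Linear(512, 768)
    with torch.no_grad():
        # nn.Linear stores weight as (out_features, in_features), so the
        # three weight matrices are stacked along dim 0.
        fused.weight.copy_(torch.cat([m.query.weight, m.key.weight, m.value.weight], dim=0))
        fused.bias.copy_(torch.cat([m.query.bias, m.key.bias, m.value.bias], dim=0))
    return FusedQueryKeyValue(fused)


if __name__ == "__main__":
    torch.manual_seed(0)
    unfused = QueryKeyValueUnoptimized().eval()
    fused = torch.jit.script(fuse_qkv(unfused).eval())
    x = torch.randn(128, 512)
    with torch.no_grad():
        for a, b in zip(unfused(x), fused(x)):
            print(a.shape, torch.allclose(a, b, atol=1e-5))
```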