I want to find some experiments that clearly show the benefit of torch.jit.script.
It does improve inference speed, but the gains seem smaller than I expected.
I set up a simple experiment like this:
import torch
import torch.nn as nn
import torch.jit as jit
from torch import Tensor
from typing import Tuple


class QueryKeyValueUnoptimized(nn.Module):
    """Three separate linear projections for query, key and value."""

    def __init__(self):
        super().__init__()
        self.query = nn.Linear(512, 256)
        self.key = nn.Linear(512, 256)
        self.value = nn.Linear(512, 256)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        return self.query(x), self.key(x), self.value(x)


class QueryKeyValueOptimized(nn.Module):
    """Manually fused version: one linear layer, then split its output."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(512, 768)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        x = self.linear(x)
        return x[:, :256], x[:, 256:512], x[:, 512:]


# CPU: eager vs. scripted, for both module variants
net_unoptimized_nojit_cpu = QueryKeyValueUnoptimized()
net_unoptimized_jit_cpu = jit.script(net_unoptimized_nojit_cpu)
net_optimized_nojit_cpu = QueryKeyValueOptimized()
net_optimized_jit_cpu = jit.script(net_optimized_nojit_cpu)

device = torch.device("cpu")
x = torch.zeros(128, 512, device=device)

%timeit net_unoptimized_nojit_cpu(x)
%timeit net_unoptimized_jit_cpu(x)
%timeit net_optimized_nojit_cpu(x)
%timeit net_optimized_jit_cpu(x)

# GPU: move the same modules to CUDA
device = torch.device("cuda")
net_unoptimized_nojit_cuda = net_unoptimized_nojit_cpu.to(device)
net_unoptimized_jit_cuda = net_unoptimized_jit_cpu.to(device)
net_optimized_nojit_cuda = net_optimized_nojit_cpu.to(device)
net_optimized_jit_cuda = net_optimized_jit_cpu.to(device)
x = x.to(device)

%timeit net_unoptimized_nojit_cuda(x)
%timeit net_unoptimized_jit_cuda(x)
%timeit net_optimized_nojit_cuda(x)
%timeit net_optimized_jit_cuda(x)

# GPU: script fresh modules after moving them to CUDA
net_unoptimized_cuda_jit = jit.script(QueryKeyValueUnoptimized().to(device))
net_optimized_cuda_jit = jit.script(QueryKeyValueOptimized().to(device))

%timeit net_unoptimized_cuda_jit(x)
%timeit net_optimized_cuda_jit(x)
The results are as follows.
| Manual optimization | JIT scripting | Device | Inference time (µs) |
|---|---|---|---|
| No | No | CPU | 623 |
| No | Yes | CPU | 588 |
| Yes | No | CPU | 290 |
| Yes | Yes | CPU | 277 |
| No | No | GPU | 191 |
| No | Yes | GPU | 170 |
| Yes | No | GPU | 94.3 |
| Yes | Yes | GPU | 97.5 |
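One caveat about the GPU numbers: CUDA kernels launch asynchronously, and the TorchScript executor only optimizes a function after a few calls, so I am not sure plain %timeit is entirely fair here. A rough sketch of a more careful GPU timing loop (the warm-up count of 10 and iteration count of 1000 are arbitrary choices on my part):

import torch

def time_cuda(module, x, warmup=10, iters=1000):
    with torch.no_grad():
        # Warm-up calls so TorchScript has a chance to optimize before timing.
        for _ in range(warmup):
            module(x)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            module(x)
        end.record()
        # Wait for all queued kernels to finish before reading the timer.
        torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

# e.g. time_cuda(net_optimized_jit_cuda, x)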
I have three questions.
- Are there any official or well-known performance benchmarks for JIT scripting?
- The results suggest that we still have to optimize the code manually even after applying jit.script to the Module. Is that right?
- Is the kind of fusion I did manually in this experiment (merging the three projections into one linear layer) not something PyTorch's JIT can do automatically? (See the graph-inspection sketch below.)
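For the last question, I am not sure how to verify what TorchScript actually fuses; my guess is that printing the scripted graph and its generated code is one way to look:

# Inspect what jit.script produced: the TorchScript IR graph and a
# Python-like rendering of the same graph.
print(net_unoptimized_jit_cpu.graph)
print(net_unoptimized_jit_cpu.code)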
Thanks,