I have observed that when I optimize a model with torch.compile and then run inference under torch.inference_mode, performance degrades significantly: the compiled model is even slower than the non-compiled one. However, when I place the compilation (warm-up) call inside the torch.inference_mode context as well, the problem disappears.
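To make the pattern concrete, here is a minimal sketch of the mismatched case (compile/warm up outside torch.inference_mode, then call inside it), mirroring the full script further below:

import torch
from torchvision.models import resnet18

net = resnet18().cuda().half().eval()
x = torch.rand(8, 3, 224, 224, device="cuda", dtype=torch.float16)

compiled = torch.compile(net, mode="reduce-overhead", backend="inductor")

# warm-up / first (compiling) call with inference mode DISABLED
_ = compiled(x)

# ...but the actual inference calls run with inference mode ENABLED
with torch.inference_mode():
    y = compiled(x)  # this is the slow case (case 3 below)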
My results, tested with torch==2.1.0+cu118 on an A10 GPU:
# 1. both disabled: ~2.5X speed-up, great!
enable inference mode when compile: False
enable inference mode when benchmark: False
not compiled latency: 0.002047449684143066
compiled latency: 0.0008150367736816407, speed up: 2.5120948529652662
# 2. both enabled: ~2.35X speed-up, also great!
enable inference mode when compile: True
enable inference mode when benchmark: True
not compiled latency: 0.0019193056106567383
compiled latency: 0.0008156064033508301, speed up: 2.3532252846121366
# 3. inference mode enabled when benchmarking, disabled when compiling:
# the compiled model is slower than the non-compiled model! Help!
# (the non-compiled model's latency is still ~0.002 s)
enable inference mode when compile: False
enable inference mode when benchmark: True
not compiled latency: 0.0019556352615356445
compiled latency: 0.010692658996582031, speed up: 0.18289513040309005
# 4. inference mode disabled when benchmarking, enabled when compiling: also disappointing
enable inference mode when compile: True
enable inference mode when benchmark: False
not compiled latency: 0.0020351423263549806
compiled latency: 0.011507730865478516, speed up: 0.17685001067066186
And here is the code to reproduce my results:
import torch
from torchvision.models import resnet18
import sys
from contextlib import nullcontext

# enable inference mode when compiling and/or when benchmarking?
enable_when_compile = sys.argv[1] == "true"
enable_when_benchmark = sys.argv[2] == "true"


def _benchmark(
    iters,
    f,
    context,  # torch.inference_mode or nullcontext
    *args,
    **kwargs,
) -> float:
    """Estimates the average duration of a single inference call in seconds.

    Returns:
        Estimated average duration in seconds for a single inference call.
    """
    # warm-up call outside the timed region
    with context():
        f(*args, **kwargs)
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    with context():
        start_event.record()
        for _ in range(iters):
            f(*args, **kwargs)
        end_event.record()
    torch.cuda.synchronize()

    elapsed_time_s = start_event.elapsed_time(end_event) * 1.0e-3  # ms -> s
    avg_time_s = elapsed_time_s / iters
    print("Estimated average time duration: {:.6f} s".format(avg_time_s))
    return avg_time_s


class BenchmarkRunner(object):
    def __init__(self, use_inference_mode: bool):
        self.context = torch.inference_mode if use_inference_mode else nullcontext

    def __call__(self, iters, f, *args, **kwargs) -> float:
        return _benchmark(iters, f, self.context, *args, **kwargs)


@torch.no_grad()
def run():
    input = [torch.rand(8, 3, 224, 224).to(torch.device("cuda"), dtype=torch.float16)]
    net = resnet18(pretrained=False).cuda().half()
    net.eval()

    # warm-up / compile call: inside or outside torch.inference_mode
    context = torch.inference_mode if enable_when_compile else nullcontext
    compiled = torch.compile(net, mode="reduce-overhead", backend="inductor")
    with context():
        _ = compiled(*input)

    latency_compiled = BenchmarkRunner(enable_when_benchmark)(10, compiled, *input)
    latency_torch = BenchmarkRunner(enable_when_benchmark)(10, net, *input)

    print(f"enable inference mode when compile: {enable_when_compile}")
    print(f"enable inference mode when benchmark: {enable_when_benchmark}")
    print(f"not compiled latency: {latency_torch}")
    print(f"compiled latency: {latency_compiled}, speed up: {latency_torch / latency_compiled}")


if __name__ == "__main__":
    run()
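For completeness: the two command-line arguments map to the flags above (sys.argv[1] controls inference mode during the compile call, sys.argv[2] during the benchmark). So, assuming the script is saved as repro.py (the filename is just for illustration), case 3 corresponds to python repro.py false true and case 4 to python repro.py true false.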