`torch.compile` + `torch.no_grad` not working for Mask R-CNN

Greetings,

I have been trying to get `torch.compile` to work with a Mask R-CNN model, but I have not been able to make it work with gradients disabled. An error also occurs if I use `torch.inference_mode` instead of `torch.no_grad`. The following snippet should reproduce the error:

```python
import torch
import torchvision

device = "cuda:0"

x = torch.rand(1, 3, 800, 800, device=device)

model = torchvision.models.detection.maskrcnn_resnet50_fpn()
model = model.to(device)
model = torch.compile(model)
model.eval()

# This works (gradients enabled)
# _ = model(x)

# This does not (raises BackendCompilerFailed, see output below)
with torch.no_grad():
    _ = model(x)
```

The output:

[2023-03-21 17:27:03,307] torch._inductor.utils: [WARNING] DeviceCopy in input program
[2023-03-21 17:27:03,307] torch._inductor.utils: [WARNING] DeviceCopy in input program
[2023-03-21 17:27:03,307] torch._inductor.utils: [WARNING] DeviceCopy in input program
[2023-03-21 17:27:03,307] torch._inductor.utils: [WARNING] DeviceCopy in input program
[2023-03-21 17:27:03,308] torch._inductor.utils: [WARNING] DeviceCopy in input program
/tmp/a/env/lib/python3.8/site-packages/torch/_inductor/compile_fx.py:90: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
[2023-03-21 17:27:06,965] torch._inductor.graph: [ERROR] Error from lowering
Traceback (most recent call last):
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_inductor/graph.py", line 333, in call_function
    out = lowerings[target](*args, **kwargs)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 225, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 1431, in convolution
    ir.Convolution.create(
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_inductor/ir.py", line 3244, in create
    return Convolution(
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_inductor/ir.py", line 3088, in __init__
    super().__init__(layout, inputs, constant_args)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_inductor/ir.py", line 2827, in __init__
    super().__init__(None, layout, self.unwrap_storage(inputs), constant_args)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_inductor/ir.py", line 2378, in unwrap_storage
    assert isinstance(x, (Buffer, ReinterpretView)), x
AssertionError: Pointwise(
  'cuda',
  torch.float32,
  tmp0 = load(buf0, i3 + 14 * i2 + 196 * i1 + 50176 * i0)
  tmp1 = load(arg1_1, i1)
  tmp2 = tmp0 + tmp1
  tmp3 = relu(tmp2)
  return tmp3
  ,
  ranges=[0, 256, 14, 14],
  origins={arg0_1, relu, convolution, arg1_1, arg12_1}
)
Traceback (most recent call last):
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_inductor/graph.py", line 333, in call_function
    out = lowerings[target](*args, **kwargs)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 225, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 1431, in convolution
    ir.Convolution.create(
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_inductor/ir.py", line 3244, in create
    return Convolution(
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_inductor/ir.py", line 3088, in __init__
    super().__init__(layout, inputs, constant_args)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_inductor/ir.py", line 2827, in __init__
    super().__init__(None, layout, self.unwrap_storage(inputs), constant_args)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_inductor/ir.py", line 2378, in unwrap_storage
    assert isinstance(x, (Buffer, ReinterpretView)), x
AssertionError: Pointwise(
  'cuda',
  torch.float32,
  tmp0 = load(buf0, i3 + 14 * i2 + 196 * i1 + 50176 * i0)
  tmp1 = load(arg1_1, i1)
  tmp2 = tmp0 + tmp1
  tmp3 = relu(tmp2)
  return tmp3
  ,
  ranges=[0, 256, 14, 14],
  origins={arg0_1, relu, convolution, arg1_1, arg12_1}
)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 670, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/__init__.py", line 1390, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 455, in compile_fx
    return aot_autograd(
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/backends/common.py", line 48, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2805, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2498, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1713, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1326, in aot_dispatch_base
    compiled_fw = aot_config.fw_compiler(fw_module, flat_args_with_views_handled)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 430, in fw_compiler
    return inner_compile(
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/debug_utils.py", line 595, in debug_wrapper
    compiled_fn = compiler_fn(gm, example_inputs)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_inductor/debug.py", line 239, in inner
    return fn(*args, **kwargs)
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 176, in compile_fx_inner
    graph.run(*example_inputs)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_inductor/graph.py", line 194, in run
    return super().run(*args)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/fx/interpreter.py", line 136, in run
    self.env[node] = self.run_node(node)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_inductor/graph.py", line 405, in run_node
    result = self.call_function(n.target, args, kwargs)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_inductor/graph.py", line 337, in call_function
    raise LoweringException(e, target, args, kwargs) from e
torch._inductor.exc.LoweringException: AssertionError: Pointwise(
  'cuda',
  torch.float32,
  tmp0 = load(buf0, i3 + 14 * i2 + 196 * i1 + 50176 * i0)
  tmp1 = load(arg1_1, i1)
  tmp2 = tmp0 + tmp1
  tmp3 = relu(tmp2)
  return tmp3
  ,
  ranges=[0, 256, 14, 14],
  origins={arg0_1, relu, convolution, arg1_1, arg12_1}
)
  target: aten.convolution.default
  args[0]: TensorBox(StorageBox(
    Pointwise(
      'cuda',
      torch.float32,
      tmp0 = load(buf0, i3 + 14 * i2 + 196 * i1 + 50176 * i0)
      tmp1 = load(arg1_1, i1)
      tmp2 = tmp0 + tmp1
      tmp3 = relu(tmp2)
      return tmp3
      ,
      ranges=[0, 256, 14, 14],
      origins={arg0_1, relu, convolution, arg1_1, arg12_1}
    )
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg2_1', layout=FixedLayout('cuda', torch.float32, size=[256, 256, 3, 3], stride=[2304, 9, 3, 1]))
  ))
  args[2]: TensorBox(StorageBox(
    InputBuffer(name='arg3_1', layout=FixedLayout('cuda', torch.float32, size=[256], stride=[1]))
  ))
  args[3]: [1, 1]
  args[4]: [1, 1]
  args[5]: [1, 1]
  args[6]: False
  args[7]: [0, 0]
  args[8]: 1

While executing %convolution_1 : [#users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%relu, %arg2_1, %arg3_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), kwargs = {})
Original traceback:
  File "/tmp/a/env/lib/python3.8/site-packages/torchvision/models/detection/roi_heads.py", line 804, in <graph break in forward>
    mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "nograd_compile_example.py", line 18, in <module>
    _ = model(x)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/tmp/a/env/lib/python3.8/site-packages/torchvision/models/detection/generalized_rcnn.py", line 83, in forward
    images, targets = self.transform(images, targets)
  File "/tmp/a/env/lib/python3.8/site-packages/torchvision/models/detection/generalized_rcnn.py", line 101, in <graph break in forward>
    features = self.backbone(images.tensors)
  File "/tmp/a/env/lib/python3.8/site-packages/torchvision/models/detection/generalized_rcnn.py", line 104, in <graph break in forward>
    proposals, proposal_losses = self.rpn(images, features, targets)
  File "/tmp/a/env/lib/python3.8/site-packages/torchvision/models/detection/generalized_rcnn.py", line 105, in <graph break in forward>
    detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/a/env/lib/python3.8/site-packages/torchvision/models/detection/roi_heads.py", line 761, in forward
    box_features = self.box_roi_pool(features, proposals, image_shapes)
  File "/tmp/a/env/lib/python3.8/site-packages/torchvision/models/detection/roi_heads.py", line 775, in <graph break in forward>
    boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
  File "/tmp/a/env/lib/python3.8/site-packages/torchvision/models/detection/roi_heads.py", line 804, in <graph break in forward>
    mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 404, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
    return fn(*args, **kwargs)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
    return _compile(
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
    out_code = transform_code_object(code, transform)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
    transformations(instructions, code_options)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
    tracer.run()
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
    super().run()
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1792, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 541, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 588, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/tmp/a/env/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 675, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised LoweringException: AssertionError: Pointwise(
  'cuda',
  torch.float32,
  tmp0 = load(buf0, i3 + 14 * i2 + 196 * i1 + 50176 * i0)
  tmp1 = load(arg1_1, i1)
  tmp2 = tmp0 + tmp1
  tmp3 = relu(tmp2)
  return tmp3
  ,
  ranges=[0, 256, 14, 14],
  origins={arg0_1, relu, convolution, arg1_1, arg12_1}
)
  target: aten.convolution.default
  args[0]: TensorBox(StorageBox(
    Pointwise(
      'cuda',
      torch.float32,
      tmp0 = load(buf0, i3 + 14 * i2 + 196 * i1 + 50176 * i0)
      tmp1 = load(arg1_1, i1)
      tmp2 = tmp0 + tmp1
      tmp3 = relu(tmp2)
      return tmp3
      ,
      ranges=[0, 256, 14, 14],
      origins={arg0_1, relu, convolution, arg1_1, arg12_1}
    )
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg2_1', layout=FixedLayout('cuda', torch.float32, size=[256, 256, 3, 3], stride=[2304, 9, 3, 1]))
  ))
  args[2]: TensorBox(StorageBox(
    InputBuffer(name='arg3_1', layout=FixedLayout('cuda', torch.float32, size=[256], stride=[1]))
  ))
  args[3]: [1, 1]
  args[4]: [1, 1]
  args[5]: [1, 1]
  args[6]: False
  args[7]: [0, 0]
  args[8]: 1

While executing %convolution_1 : [#users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%relu, %arg2_1, %arg3_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1), kwargs = {})
Original traceback:
  File "/tmp/a/env/lib/python3.8/site-packages/torchvision/models/detection/roi_heads.py", line 804, in <graph break in forward>
    mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)


Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
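
The last lines of the error point to a fallback. Below is a minimal sketch of applying it, assuming it is acceptable for the failing frame to run eagerly. Note that this only hides the lowering failure rather than fixing it. The small `nn.Sequential` model and the `aot_eager` backend are hypothetical stand-ins chosen so the snippet runs on CPU without a GPU or the Mask R-CNN model.

```python
import torch
import torch._dynamo
import torch.nn as nn

# Setting suggested by the error message above: when compiling a frame fails,
# Dynamo falls back to running that frame eagerly instead of raising
# BackendCompilerFailed. This masks the bug; it does not fix it.
torch._dynamo.config.suppress_errors = True

# Hypothetical stand-in for the Mask R-CNN model from the repro.
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU()).eval()
compiled = torch.compile(model, backend="aot_eager")

with torch.no_grad():
    out = compiled(torch.rand(3, 8))
print(out.shape)
```

In the original repro, the same assignment would go before the `with torch.no_grad():` call.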


I’ve tried the same test with a different model (`torchvision.models.resnet50`), and in that case the error does not occur, so the problem may be specific to the Mask R-CNN model.
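
The traceback fails inside Inductor's lowering (`torch._inductor.graph`), on an `aten.convolution` whose input is a fused add + ReLU (the `Pointwise` node in the log). One way to narrow this down is to keep the same Dynamo capture but switch to the `aot_eager` backend, which runs AOTAutograd without Inductor's lowering step. The sketch below uses a hypothetical stand-in, `mask_head_like`, that mirrors the 256-channel, 3x3, padding-1 convolution shown in the log; the helper name `run_compiled_no_grad` is also illustrative.

```python
import torch
import torch.nn as nn


def run_compiled_no_grad(model: nn.Module, x: torch.Tensor, backend: str) -> torch.Tensor:
    """Compile `model` with the given backend and run it once under torch.no_grad()."""
    compiled = torch.compile(model, backend=backend)
    with torch.no_grad():
        return compiled(x)


# Hypothetical stand-in for the failing subgraph: conv -> ReLU -> conv with the
# same 256-channel, 3x3, stride 1, padding 1 configuration as %convolution_1.
mask_head_like = nn.Sequential(
    nn.Conv2d(256, 256, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(256, 256, kernel_size=3, padding=1),
).eval()

x = torch.rand(2, 256, 14, 14)

# Plain eager result, used as the reference.
with torch.no_grad():
    reference = mask_head_like(x)

# "aot_eager" traces through Dynamo and AOTAutograd but never calls Inductor,
# which is where the traceback above fails.
out = run_compiled_no_grad(mask_head_like, x, backend="aot_eager")
print(torch.allclose(out, reference, atol=1e-5))
```

For the actual repro, the helper could be called with a fresh, uncompiled Mask R-CNN instance (the snippet's `model` is already wrapped by `torch.compile`). If `"aot_eager"` succeeds while the default `"inductor"` backend fails, the problem is likely in Inductor rather than in Dynamo's capture. The log's `ranges=[0, 256, 14, 14]` also shows that the failing convolution's input had an empty first dimension, so trying the stand-in with `torch.rand(0, 256, 14, 14)` may be informative as well.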

Environment used for testing:

$ python -m torch.utils.collect_env
Collecting environment information...
PyTorch version: 2.0.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1 
CMake version: version 3.26.0
Libc version: glibc-2.31

Python version: 3.8.10 (default, Mar 13 2023, 10:26:41)  [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-67-generic-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3060 Laptop GPU
Nvidia driver version: 525.85.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   39 bits physical, 48 bits virtual
CPU(s):                          20
On-line CPU(s) list:             0-19
Thread(s) per core:              1
Core(s) per socket:              14
Socket(s):                       1
NUMA node(s):                    1
Vendor ID:                       GenuineIntel
CPU family:                      6
Model:                           154
Model name:                      12th Gen Intel(R) Core(TM) i7-12700H
Stepping:                        3
CPU MHz:                         2700.000
CPU max MHz:                     4700,0000
CPU min MHz:                     400,0000
BogoMIPS:                        5376.00
Virtualization:                  VT-x
L1d cache:                       336 KiB
L1i cache:                       224 KiB
L2 cache:                        8,8 MiB
NUMA node0 CPU(s):               0-19
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb intel_pt sha_ni xsaveopt xsavec xgetbv1 xsaves split_lock_detect avx_vnni dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp hwp_pkg_req umip pku ospke waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize arch_lbr flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.24.2
[pip3] pytorch-lightning==1.9.4
[pip3] torch==2.0.0
[pip3] torchmetrics==0.11.3
[pip3] torchvision==0.15.1
[conda] numpy                     1.24.2                   pypi_0    pypi
[conda] pytorch-lightning         1.9.3                    pypi_0    pypi
[conda] torch                     1.13.1                   pypi_0    pypi
[conda] torchmetrics              0.11.3                   pypi_0    pypi
[conda] torchvision               0.14.1                   pypi_0    pypi

Quick note: `torch.inference_mode()` is not supported with `torch.compile`, and the performance improvements that compile brings are a superset of those from inference mode.

That said, your error is legitimate and might be worth reporting on GitHub.

Thank you for your response. I have opened an issue here: `torch.compile` + `torch.no_grad` not working for Mask R-CNN (Issue #7440, pytorch/vision on GitHub).