Seeking insights: Strange torch.compile() behavior with RegionViT in inference_mode()

I’ve encountered an interesting issue while using torch.compile() for inference with a RegionViT model and would appreciate help debugging it.

I’ve filed a detailed bug report on GitHub: "TypeError when using torch.compile with RegionViT under torch.inference_mode()" (pytorch/pytorch issue #146780).

My main goal is to understand which part of the code is causing the issue and how to fix it in my implementation.

Reproduction Steps

Install required packages:

pip install birder torch torchvision torchaudio

Run the following minimal example:

from birder.model_registry import registry
import torch

# Build the regionvit_t model, move it to the GPU and switch to eval mode
net = registry.net_factory("regionvit_t", input_channels=3, num_classes=1000)
net.to(torch.device("cuda"))
net.eval()
net = torch.compile(net)

# This call raises the TypeError described in the issue above
with torch.inference_mode():
    net(torch.rand(1, 3, 256, 256, device=torch.device("cuda")))

Debugging Attempts

I’ve added print statements to locate the failure point and observed some unexpected behavior. Here is the main forward pass, with the print locations labeled A, B, and C (reference: birder/net/regionvit.py at commit c1f5f9adb6b552af7f1cd2b9603e8a5abf8d3340, birder/birder on GitLab):

def forward(self, cls_tokens: torch.Tensor, patch_tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    (cls_tokens, patch_tokens) = self.proj(cls_tokens, patch_tokens)
    (out, mask, p_r, p_b, B, C, H, W) = convert_to_flatten_layout(cls_tokens, patch_tokens, self.ws[0])
    for blk in self.blocks:
        # print("A")
        out = blk((out, B), mask)
        # print("B")

    # print("C")
    (cls_tokens, patch_tokens) = convert_to_spatial_layout(out, B, C, H, W, self.ws, mask, p_r, p_b)

    return (cls_tokens, patch_tokens)

Interestingly:

  • Adding a print statement at point A or B (inside the loop over self.blocks) makes the code succeed
  • Adding a print statement at point C (after the loop) still results in failure
  • Adding a print statement anywhere inside the blk forward pass also makes the code succeed

Does anyone have any ideas on how to zero in on the issue? Any insights or suggestions would be greatly appreciated.