I’ve encountered an interesting issue while using torch.compile() for inference with a RegionViT model and would appreciate help debugging it.
I’ve filed a detailed bug report here: "TypeError when using torch.compile with RegionViT under torch.inference_mode()" (pytorch/pytorch issue #146780 on GitHub).
My main goal is to understand which part of the code is causing the issue and how to fix it in my implementation.
Reproduction Steps
Install required packages:
pip install birder torch torchvision torchaudio
Run the following minimal example:
from birder.model_registry import registry
import torch
net = registry.net_factory("regionvit_t", input_channels=3, num_classes=1000)
net.to(torch.device("cuda"))
net.eval()
net = torch.compile(net)
with torch.inference_mode():
    net(torch.rand(1, 3, 256, 256, device=torch.device("cuda")))
Debugging Attempts
I’ve added print statements to locate the failure point and observed some unexpected behavior. In the main forward pass (reference: birder/net/regionvit.py at commit c1f5f9adb6b552af7f1cd2b9603e8a5abf8d3340 in the birder repository on GitLab):
def forward(self, cls_tokens: torch.Tensor, patch_tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    (cls_tokens, patch_tokens) = self.proj(cls_tokens, patch_tokens)
    (out, mask, p_r, p_b, B, C, H, W) = convert_to_flatten_layout(cls_tokens, patch_tokens, self.ws[0])
    for blk in self.blocks:
        # print("A")
        out = blk((out, B), mask)
        # print("B")

    # print("C")
    (cls_tokens, patch_tokens) = convert_to_spatial_layout(out, B, C, H, W, self.ws, mask, p_r, p_b)
    return (cls_tokens, patch_tokens)
Interestingly:
- Adding print statements at points A or B makes the code succeed
- Adding print statements at point C still results in failure
- Any print statement inside the blk forward pass also leads to success
Does anyone have any ideas on how to zero in on the issue? Any insights or suggestions would be greatly appreciated.
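One generic way to narrow this down is to compile the same callable with progressively deeper backends and see which one fails first. This is a rough sketch under the assumption that the error reproduces with the built-in "eager" and "aot_eager" debugging backends; first_failing_backend is a hypothetical helper name, not part of PyTorch or birder.

```python
import torch
import torch._dynamo


def first_failing_backend(fn, example_inputs, backends=("eager", "aot_eager", "inductor")):
    """Return (backend_name, exception) for the first backend that raises,
    or (None, None) if every backend runs without an exception."""
    for backend in backends:
        # Drop earlier compilations so each backend starts clean.
        torch._dynamo.reset()
        compiled = torch.compile(fn, backend=backend)
        try:
            # Same context as the failing repro.
            with torch.inference_mode():
                compiled(*example_inputs)
        except Exception as exc:
            return backend, exc
    return None, None


if __name__ == "__main__":
    # Toy callable for illustration; with the real model this would be
    # first_failing_backend(net, (input_tensor,)).
    def toy(x):
        return (x * 2).sum()

    print(first_failing_backend(toy, (torch.ones(4),), backends=("eager", "aot_eager")))
```

Roughly, a failure already with "eager" would suggest the problem is in TorchDynamo tracing, a failure first appearing with "aot_eager" would suggest AOTAutograd, and a failure only with "inductor" would suggest code generation.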