A simplified example is shown below:
    import torch


    def test_func(h: int, w: int) -> int:
        return h * w


    class TestModule(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            height, width = x.shape[2:]

            # Handle torch.fx
            if isinstance(x, torch.fx.proxy.Proxy):
                height, width = 224, 224

            hxw = test_func(height, width)
            return x * hxw


    # Seems to work okay
    torch.fx.symbolic_trace(TestModule())
Will this solution cause any issues? It seemed to work when I tested it, but I may be unaware of potential issues with this fix.
I want to deploy this solution in custom GoogLeNet and Inception models so that I can use torchvision.models.feature_extraction.create_feature_extractor on them.