A simplified example is shown below:
import torch
import torch.fx


def test_func(h: int, w: int) -> int:
    return h * w


class TestModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        height, width = x.shape[2:]
        # Handle torch.fx: during symbolic tracing x is a Proxy, so use fixed sizes
        if isinstance(x, torch.fx.proxy.Proxy):
            height, width = 224, 224
        hxw = test_func(height, width)
        return x * hxw


# Seems to work okay
torch.fx.symbolic_trace(TestModule())
Will this solution cause any issues? It seemed to work when I tested it, but I may be unaware of potential issues with this fix.
I want to use this solution in custom GoogLeNet and Inception models so that I can call torchvision.models.feature_extraction.get_graph_node_names and torchvision.models.feature_extraction.create_feature_extractor on them.