Is there a way to convert a torch.fx Proxy to a base Python type (for example, an int)?
For example:
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
class M(nn.Module):
    """Minimal module that builds a random tensor sized from its input.

    Works in eager mode, but cannot be symbolically traced as written:
    see the comments in ``forward``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return a 1-D random tensor with ``x.size(1)`` elements."""
        # In eager mode this is a plain int; under symbolic tracing the
        # input is a Proxy, so the result of .size(1) is a Proxy as well.
        x = x.size(1)
        # torch.randn needs a concrete size and cannot accept a Proxy,
        # which is why tracing this module fails here.
        a = torch.randn(x)
        return a
# Tracing feeds a Proxy through forward(), so the torch.randn call inside
# M.forward receives a Proxy instead of an int and the trace fails.
m = M()
symbolic_trace(m)
During symbolic tracing, x.size(1) returns a Proxy rather than an int, and torch.randn cannot accept a Proxy as its size argument.
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
class M(nn.Module):
    """Module that builds a random tensor sized from its last eager input.

    Workaround for symbolic tracing: the size is cached on the instance
    during an eager call, so that tracing can use the cached int instead
    of a Proxy. The module must therefore be called once with a real
    tensor before it is traced; otherwise the cached size is still 0.
    """

    def __init__(self):
        super().__init__()
        # Size cached by the most recent eager forward(); 0 until then.
        self.n = 0

    def forward(self, x):
        """Return a 1-D random tensor with ``self.n`` elements.

        When ``x`` is a real tensor, ``self.n`` is first refreshed from
        ``x.size(1)``. When ``x`` is a ``torch.fx.Proxy`` (i.e. during
        tracing), the previously cached value is used unchanged.
        """
        # Only a real tensor yields a concrete int from .size(1); under
        # tracing it would be a Proxy, which torch.randn cannot accept.
        if not isinstance(x, torch.fx.Proxy):
            self.n = x.size(1)
        a = torch.randn(self.n)
        return a
# The order below matters: the eager call must come before the trace so
# that m.n holds the real size; otherwise tracing would use the initial 0.
m = M()
data = torch.randn(1, 3, 224 ,224)
m(data) # eager forward pass: caches data.size(1) in m.n before tracing
symbolic_trace(m)