Hi,
It seems that I am encountering a bug in torch.fx: it wrongly records a sigmoid function call as a method call:
```python
import torch
import torch.fx
import torch.nn as nn
import torch.nn.functional as F


class ConvAutoencoder(nn.Module):
    def __init__(self):
        super(ConvAutoencoder, self).__init__()
        # Encoder
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 4, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # Decoder
        self.t_conv1 = nn.ConvTranspose2d(4, 16, 2, stride=2)
        self.t_conv2 = nn.ConvTranspose2d(16, 3, 2, stride=2)
        self.sigmoid = nn.Sigmoid()  # instantiate the module (not used in forward)

    def forward(self, x):
        feat1 = F.relu(self.conv1(x))
        x = self.pool(feat1)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = F.relu(self.t_conv1(x))
        # Concatenate the decoder output with the skip connection, then apply sigmoid
        x = F.sigmoid(torch.cat([self.t_conv2(x), feat1], 1))
        return x


model = torch.fx.symbolic_trace(ConvAutoencoder())
for node in model.graph.nodes:
    print(f'{node.op} name: {node.name} target: {node.target} ({type(node.target)}) args: {node.args} kwargs: {node.kwargs}')
```
Output:
```
placeholder name: x target: x (<class 'str'>) args: () kwargs: {}
call_module name: conv1 target: conv1 (<class 'str'>) args: (x,) kwargs: {}
call_function name: relu_1 target: <function relu at 0x7f3ccc0104d0> (<class 'function'>) args: (conv1,) kwargs: {'inplace': False}
call_module name: pool target: pool (<class 'str'>) args: (relu_1,) kwargs: {}
call_module name: conv2 target: conv2 (<class 'str'>) args: (pool,) kwargs: {}
call_function name: relu_2 target: <function relu at 0x7f3ccc0104d0> (<class 'function'>) args: (conv2,) kwargs: {'inplace': False}
call_module name: pool_1 target: pool (<class 'str'>) args: (relu_2,) kwargs: {}
call_module name: t_conv1 target: t_conv1 (<class 'str'>) args: (pool_1,) kwargs: {}
call_function name: relu_3 target: <function relu at 0x7f3ccc0104d0> (<class 'function'>) args: (t_conv1,) kwargs: {'inplace': False}
call_module name: t_conv2 target: t_conv2 (<class 'str'>) args: (relu_3,) kwargs: {}
call_function name: cat_1 target: <built-in method cat of type object at 0x7f3ce1503640> (<class 'builtin_function_or_method'>) args: ([t_conv2, relu_1], 1) kwargs: {}
call_method name: sigmoid_1 target: sigmoid (<class 'str'>) args: (cat_1,) kwargs: {}
output name: output target: output (<class 'str'>) args: (sigmoid_1,) kwargs: {}
```
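To narrow down whether the node kind depends on how sigmoid is spelled, here is a minimal sketch that traces three spellings of the same op (`F.sigmoid`, `torch.sigmoid`, and the `Tensor.sigmoid` method) and prints the op kind FX records for each. `SigmoidSpellings` and `node_kinds` are hypothetical names made up for this check; the exact node names printed may differ between PyTorch versions.

```python
import torch
import torch.fx
import torch.nn as nn
import torch.nn.functional as F


class SigmoidSpellings(nn.Module):
    """Applies sigmoid three different ways so the traced node kinds can be compared."""

    def forward(self, x):
        a = F.sigmoid(x)      # functional API, as in the model above
        b = torch.sigmoid(x)  # top-level torch function
        c = x.sigmoid()       # Tensor method
        return a, b, c


def node_kinds(module):
    """Return (name, op, target) for every call_function / call_method node in the traced graph."""
    traced = torch.fx.symbolic_trace(module)
    return [
        (node.name, node.op, node.target)
        for node in traced.graph.nodes
        if node.op in ("call_function", "call_method")
    ]


for name, op, target in node_kinds(SigmoidSpellings()):
    print(f"{name}: {op} -> {target}")
```

Comparing the line for `F.sigmoid` with the other two shows which spelling it ends up matching in the installed version.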
What is happening?
Is this really a bug?
It is currently a problem for me as I am trying to rebuild a full model with some variations using proxy calls, following this discussion.
While I can write a quick fix specific to the sigmoid case, I am not sure whether this affects only sigmoid or other functions as well, and whether this behaviour is consistent.
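For reference, the sigmoid-specific quick fix could look something like the sketch below: a predicate that treats a node as "sigmoid" whether FX recorded it as `call_function`, `call_method`, or `call_module` (an `nn.Sigmoid` submodule). `is_sigmoid`, `SIGMOID_FUNCTIONS`, `SIGMOID_METHODS` and `Tiny` are my own hypothetical names, not part of torch.fx, and the sets only cover the spellings listed; other functions may need their own entries.

```python
import torch
import torch.fx
import torch.nn as nn
import torch.nn.functional as F

# Targets that all mean "apply sigmoid", grouped by the FX op kind they can appear under.
SIGMOID_FUNCTIONS = {torch.sigmoid, F.sigmoid}
SIGMOID_METHODS = {"sigmoid"}


def is_sigmoid(node: torch.fx.Node, modules: dict) -> bool:
    """Return True if `node` applies a sigmoid, however the tracer recorded it."""
    if node.op == "call_function":
        return node.target in SIGMOID_FUNCTIONS
    if node.op == "call_method":
        return node.target in SIGMOID_METHODS
    if node.op == "call_module":
        # For call_module nodes, target is the submodule's qualified name.
        return isinstance(modules[node.target], nn.Sigmoid)
    return False


class Tiny(nn.Module):
    """Uses every sigmoid spelling once, to exercise the predicate."""

    def __init__(self):
        super().__init__()
        self.act = nn.Sigmoid()

    def forward(self, x):
        return self.act(torch.sigmoid(F.sigmoid(x).sigmoid()))


traced = torch.fx.symbolic_trace(Tiny())
modules = dict(traced.named_modules())
sigmoid_nodes = [n.name for n in traced.graph.nodes if is_sigmoid(n, modules)]
print(sigmoid_nodes)
```

When rebuilding the model node by node, checking `is_sigmoid(node, modules)` instead of `node.op == "call_function"` would keep the rebuild independent of which spelling the original code used.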
Thank you in advance for your interest.