Error: When torch.utils.cpp_extension and torch.fx are used together

base.cpp

#include <torch/extension.h>

torch::Tensor base_forward(torch::Tensor x, torch::Tensor w, torch::Tensor b){
    auto o = w * x + b;
    return o;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
    m.def("forward", &base_forward, "BASE forward");
}
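
A minimal sketch of a setup.py that builds base.cpp into an importable base_cpp module (assuming base.cpp sits next to it; this is only one possible build setup, since the exact build step can vary):

# setup.py for the extension -- a sketch, names and paths are assumptions
from setuptools import setup
from torch.utils.cpp_extension import CppExtension, BuildExtension

setup(
    name="base_cpp",
    # Compile base.cpp into a Python extension module named "base_cpp".
    ext_modules=[CppExtension("base_cpp", ["base.cpp"])],
    cmdclass={"build_ext": BuildExtension},
)

After something like python setup.py build_ext --inplace, the resulting base_cpp extension can be imported as in module.py below.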

module.py

import torch
import torch.nn as nn
import base_cpp
from torch.fx import Tracer

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        x = self.conv(x)
        out = base_cpp.forward(x, torch.randn(1), torch.randn(1))
        return out

module = M()

nodes = Tracer().trace(module).nodes

Error

TypeError: forward(): incompatible function arguments. The following argument types are supported:
    1. (arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor) -> at::Tensor

Invoked with: Proxy(conv), tensor([2.0171]), tensor([-1.9950])

I think you might try your luck with torch.fx.wrap, described in the non-torch functions section of the fx documentation.

Best regards

Thomas

P.S.: The more detailed explanation for the error is that during tracing, the function is called not with a tensor x but with a proxy object, whose operations are appended to the trace. wrap is a mechanism to make function calls that receive these proxy objects insert a record of their invocation into the trace instead of being executed.
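
To make that concrete, here is a minimal sketch of the wrap mechanism using a plain Python function in place of the C++ extension (the names affine and Demo are made up for illustration):

import torch
import torch.fx

# An "opaque" helper that fx should not trace into.
def affine(x, w, b):
    return w * x + b

# Record calls to affine as a single call_function node during tracing.
torch.fx.wrap("affine")

class Demo(torch.nn.Module):
    def forward(self, x):
        # During tracing, x is a Proxy; wrap() makes this call show up in
        # the graph instead of being executed on the Proxy.
        return affine(x, torch.randn(1), torch.randn(1))

traced = torch.fx.symbolic_trace(Demo())
print(traced.graph)  # contains a call_function node for affine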

Dear friend,
I really appreciate your help. However, I still have a question about how to wrap custom methods.

temp.py

import torch
import base_cpp
from torch.fx import symbolic_trace
import torch.fx

def pybase_cpp(x):
    return base_cpp.forward(x, torch.randn(1), torch.randn(1))  # base_cpp.forward has no __code__ attribute

torch.fx.wrap('base_cpp.forward')
pybase_cpp(torch.randn(1))
traced = symbolic_trace(pybase_cpp)

Error

KeyError: 'base_cpp.forward'

So, is there a good way to wrap custom methods? Thanks very much.

I think you need to put it in the global namespace (or at least that is a feasible workaround):

import torch
import torch.utils.cpp_extension

cpp_source = """
#include <torch/extension.h>

torch::Tensor base_forward(torch::Tensor x, torch::Tensor w, torch::Tensor b){
    auto o = w * x + b;
    return o;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
    m.def("forward", &base_forward, "BASE forward");
}
"""

import torch
import torch.nn as nn

base_cpp = torch.utils.cpp_extension.load_inline("base_cpp", cpp_source)

from torch.fx import Tracer

base_cpp_forward = base_cpp.forward
torch.fx.wrap('base_cpp_forward')

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        x = self.conv(x)
        out = base_cpp_forward(x, torch.randn(1), torch.randn(1))
        return out

module = M()

traced = torch.fx.symbolic_trace(module)
print(traced.code)
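
A variant of the same workaround, in case you prefer not to alias base_cpp.forward directly: wrap it in a small module-level Python function and register that with torch.fx.wrap. This is only a sketch on top of the load_inline setup above, and the names base_cpp_call and M2 are made up for illustration. Because the wrapper is an ordinary Python function with a simple global name, it avoids both the missing __code__ issue and the KeyError (torch.fx.wrap with a string looks the name up in the module's globals, where 'base_cpp.forward' is not a key):

import torch
import torch.nn as nn
import torch.fx

# base_cpp is the extension module produced by load_inline above.

def base_cpp_call(x, w, b):
    # Ordinary Python wrapper around the pybind11 function.
    return base_cpp.forward(x, w, b)

# Register the wrapper so tracing records it as a call_function node.
torch.fx.wrap("base_cpp_call")

class M2(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        x = self.conv(x)
        return base_cpp_call(x, torch.randn(1), torch.randn(1))

print(torch.fx.symbolic_trace(M2()).code)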

P.S.: Kudos for posting code, but it seems to have a lot of small typos.


It’s awesome!!!

Thanks, my friend. :+1: :+1: :+1:


Hi, dear friend, I tried to convert this .py file into a .so file, but when I call the .so file, the torch.fx.wrap function no longer seems to take effect.

a.cpp : this is the ops file

#include <torch/extension.h>

torch::Tensor base_cpp_forward(torch::Tensor x, torch::Tensor w, torch::Tensor b){
    auto o = w * x + b;
    return o;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
    m.def("forward", &base_cpp_forward, "BASE forward");
}

m.py : this defines the model using the op

import torch
import torch.nn as nn
import torch.fx
from a import *

torch.fx.wrap('base_cpp_forward')

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        x = self.conv(x)
        out = base_cpp_forward(x, torch.randn(1), torch.randn(1))
        return out

setup.py : converts the .py into a .so

from distutils.core import setup
from Cython.Build import cythonize

setup(name='m', ext_modules = cythonize(['m.py']))

t.py : tests the m.so file

from m import M  # this m is the compiled .so file
from torchvision.models.feature_extraction import get_graph_node_names

m = M()
train_node = get_graph_node_names(m)

The error, from _symbolic_trace.py:

assert isinstance(fn, FunctionType)  # this assertion fails, i.e. the condition is False

When I bypass this assertion, I get:

TypeError: forward(): incompatible function arguments. The following argument types are supported:
    1. (arg0: at::Tensor, arg1: at::Tensor, arg2: at::Tensor) -> at::Tensor

Invoked with: Proxy(conv), tensor([2.0171]), tensor([-1.9950])

How can I deal with this? Please help me, thanks.