base.cpp
#include <torch/extension.h>
torch::Tensor base_forward(torch::Tensor x, torch:Tensor w, torch::Tensor b){
atuo o = w * x + b;
return o
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
m.def("forward", &base_forward, 'BASE forward');
}
module.py
import torch
import torch.nn as nn
import base_cpp
form torch.fx import Tracer
class M(nn,Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(1, 1. 1)
def forward(self, x):
x = self.conv(x)
out = base_cpp.forward(x, torch.randn(1), torch.randn(1))
return out
module = M()
nodes = Tracer().trace(module).nodes
Error
TypeError: forward():incompatible function arguments. The following argument types are supported:
1.(argo:at::Tensor, arg1: at:Tensor, arg2: at::Tensor)->at:Tensor
Invoked with: proxy(conv), tensor([2.0171]), tensor([-1.9950])