Hi all,
I try to write a C extension according to this: http://pytorch.org/tutorials/advanced/c_extension.html.
Everything is OK when I do the step 1. I can generate the _ext source folder and the .so file.
However, when I do the second step, something wrong happens.
My code is:
import torch
from _ext import my_lib
from torch.autograd import Function
from torch.nn import Module
import torch.nn as nn
from torch.autograd import Variable
class MyAddFunction(Function):
def forward(self, input1, input2):
output = torch.FloatTensor()
my_lib.my_lib_add_forward(input1, input2, output)
return output
def backward(self, grad_output):
grad_input = torch.FloatTensor()
my_lib.my_lib_add_backward(grad_output, grad_input)
return grad_input
class MyAddModule(Module):
def forward(self, input1, input2):
return MyAddFunction()(input1, input2)
class MyNetwork(nn.Module):
def __init__(self):
super(MyNetwork, self).__init__(
add = MyAddModule(),
)
def forward(self, input1, input2):
return self.add(input1, input2)
model = MyNetwork()
input1, input2 = Variable(torch.randn(5, 5)), Variable(torch.randn(5, 5))
print(model(input1, input2))
print(input1 + input2)
When I ran it, I got an error:
File "./test.py", line 40, in <module>
model = MyNetwork()
File "./test.py", line 32, in __init__
add = MyAddModule(),
TypeError: __init__() got an unexpected keyword argument 'add'
Could you tell me what’s wrong with it? Thanks!