Hello everyone,
I am writing a simple python c++ extension for Linear. But I encountered a problem:
My C++ extension code is as follows:
#include<vector>
#include<torch/extension.h>
torch::Tensor d_sigmoid(const torch::Tensor &z)
{
auto s = torch::sigmoid(z);
return (1 - s) * s;
}
torch::Tensor dense_forward(const torch::Tensor &input,const torch::Tensor &weights,const torch::Tensor &bias)
{
auto output = torch::mm(input,weights.t()) + bias; //mm(input,weights.t()) eg. input shape 1*2; weights shape 4*2; output shape 1*4; bias 1*4
output = torch::sigmoid(output);
return output;
}
std::vector<torch::Tensor> dense_backward(const torch::Tensor &grad_output, const torch::Tensor &input, const torch::Tensor &output,const torch::Tensor &weights, const torch::Tensor &bias)
{
auto output_d_sigmoid = d_sigmoid(output);
auto grad = grad_output * output_d_sigmoid;
auto grad_weights = torch::mm(grad.t(),input); // 本层的权重的梯度
auto grad_bias = grad.sum(0,/*keepdim=*/false); //偏置层梯度
auto grad_input = torch::mm(grad,weights); // 传给前一层的梯度
return {grad_input, grad_weights, grad_bias};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("forward", &dense_forward, "dense forward");
m.def("backward", &dense_backward,"dense backward");
}
My python code is as follows:
import torch
import torch.nn as nn
from torch.autograd import Variable
import math
import dnn
class densefunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weights, bias):
outputs = dnn.forward(x, weights, bias)
ctx.save_for_backward(x, weights, bias, outputs)
return outputs
@staticmethod
def backward(ctx,grad_output):
x, weight, bias, output = ctx.saved_tensors
grad_input, grad_weight, grad_bias = dnn.backward(grad_output, x, output, weight, bias)
return grad_input, grad_weight, grad_bias
class dense(nn.Module):
def __init__ (self,input_features, output_features):
super(dense,self).__init__()
self.input_features = input_features
self.output_features = output_features
self.weight = nn.Parameter(torch.empty(output_features,input_features))
self.bias = nn.Parameter(torch.empty(output_features))
self.weight.data.uniform_(-0.1,0.1)
self.bias.data.uniform_(-0.1,0.1)
def forward(self,x):
return densefunction.apply(x,self.weight,self.bias)
if __name__ == '__main__':
x = torch.randn((4,5))
dnn = dense(5,3)
print(dnn(x))
When running the code,I encountered a problem:
Traceback (most recent call last):
File "main.py", line 41, in <module>
print(dnn(x))
File "/data/zh/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "main.py", line 34, in forward
return densefunction.apply(x,self.weight,self.bias)
File "main.py", line 12, in forward
outputs = dnn.forward(x, weights, bias)
TypeError: forward() takes 2 positional arguments but 4 were given
Hope someone help me to solve this problem. Thanks!