Hi
I’m trying to write a small extension to PyTorch in C++. I’ve looked at the examples provided, and the syntax is quite similar to the Python syntax.
Part - 1
Here is my code for the extension:

```cpp
#include <torch/extension.h>
#include <vector>
#include <iostream>

torch::Tensor compute_grad(torch::Tensor in, torch::Tensor out)
{
    std::cout << "Input is:" << in << "\n" << "Output is: " << out << "\n";
    auto b = out.sum();
    std::cout << "Out_sum:" << "\t" << b << std::endl;
    b.backward();
    std::cout << __LINE__ << std::endl;
    return in.grad();
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("compute_grad", &compute_grad, "Compute Gradient");
}
```
and I call it with:

```python
import torch
import DenseGrad

if __name__ == "__main__":
    x = (2 * torch.ones(2, 2)).requires_grad_()
    y = x ** 3
    print(DenseGrad.compute_grad(x, y))
```
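For reference, if `backward()` completed, `compute_grad(x, y)` should return a 2x2 tensor filled with 12. With `b = sum(x ** 3)`, each element's gradient is `3 * x ** 2`, and `x` is all twos. Below is a minimal pure-Python sketch (no PyTorch involved) that checks this expected value. The central finite-difference helper `numerical_grad` is hypothetical and only here for illustration; it is not part of the extension.

```python
# Pure-Python sanity check (no PyTorch needed) of the gradient that
# compute_grad(x, y) should return for x = 2 * ones(2, 2), y = x ** 3.

def loss(values):
    """b = sum(x ** 3) over all elements, mirroring out.sum() in the extension."""
    return sum(v ** 3 for v in values)

def numerical_grad(f, values, eps=1e-5):
    """Hypothetical helper: central finite-difference estimate of df/dx_i."""
    grads = []
    for i in range(len(values)):
        plus = list(values)
        minus = list(values)
        plus[i] += eps
        minus[i] -= eps
        grads.append((f(plus) - f(minus)) / (2 * eps))
    return grads

x = [2.0, 2.0, 2.0, 2.0]            # flattened 2x2 tensor of twos
analytic = [3 * v ** 2 for v in x]  # d/dx sum(x^3) = 3x^2
numeric = numerical_grad(loss, x)
print(analytic)  # [12.0, 12.0, 12.0, 12.0]
```

The finite-difference estimates should agree with the analytic values to within a small rounding error.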
For some reason, the `backward()` call hangs, and the only way to stop the process is to kill it from `top`. Can someone tell me what I’m doing wrong?