I created a c++ function and registered it as follow:
void to_csr_index(torch::Tensor &row_ptr, const torch::Tensor &edge_index_i, int64_t num_rows) {
auto *prow_ptr = row_ptr.data_ptr<int32_t>();
for (uint32_t i = 0; i < edge_index_i.size(0); ++i) {
prow_ptr[edge_index_i[i].item<int32_t>() + 1]++;
}
for (int i = 0; i < num_rows; i++) {
prow_ptr[i + 1] += prow_ptr[i];
}
}
TORCH_LIBRARY(my_ops, m) {
m.def("to_csr_index", to_csr_index);
}
Then I called it with:
num_nodes = data[AtomicDataDict.NODE_FEATURES_KEY].shape[0]
row_ptr = torch.zeros(num_nodes, dtype=edge_index_i.dtype)
torch.ops.my_ops.to_csr_index(row_ptr, edge_index_i, num_nodes)
Here is the error message:
File "/home/dym/code/torch-ff/Torch-FF/src/models/components/encoder/_edge.py", line 71, in to_csr
torch.ops.my_ops.to_csr_index(row_ptr, edge_index_i, num_nodes)
File "/home/dym/.conda/envs/torch2.1.0/lib/python3.10/site-packages/torch/_ops.py", line 692, in __call__
return self._op(*args, **kwargs or {})
RuntimeError: expected scalar type Int but found Long
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
How to fix it?