Dear PyTorch Team,
Thank you very much for the library, which is very pleasant to use and has great documentation!
I have already written multiple CUDA extensions for PyTorch, and everything works well as long as the C++ function takes torch::Tensors or a std::vector of them as input. However, I am now in a situation where I have an Index class with about 20 members, each a pointer to some (GPU) memory that is owned by a torch::Tensor. The question is how best to pass this to C++. I have a list of options below and think that the last one is best, but I would be happy about ideas and opinions.
1. Just pass all the torch::Tensors as arguments or as a std::vector<torch::Tensor>.
   1.1 Easy implementation, easy to understand.
   1.2 Not flexible. If the indexing changes slightly, the change has to be made manually for every function.
2. Implement the Index class in C++ and expose all necessary functionality through pybind11 (a rough sketch of this is shown after this list).
   2.1 Full control in C++. Encapsulated.
   2.2 No control in Python and thus no “hackability”. I don’t know whether memory management, reference counting, etc. work out of the box in this case.
3. Currently preferred. Implement Index in Python (and in C++, but there with raw pointers only) and construct a new C++ Index each time a function/kernel is called.
   3.1 Easy implementation, easy to understand. Most things (e.g. generating the Index from data) can be implemented in Python.
   3.2 Maybe a (very) slight runtime overhead? Have to manually make sure that the pointers in C++ stay valid. Other issues I forgot.
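For reference, here is a rough sketch of what I imagine option 2 would look like (the class and member names are only illustrative, not an existing API). The C++ class stores the torch::Tensors themselves as members, so ATen's reference counting should keep the underlying storage alive as long as the Python-side object exists, and the raw pointers derived from those members stay valid.

#include <torch/extension.h>

namespace py = pybind11;

// Option 2 sketch: the class owns the tensors, so pointers derived from
// them remain valid for the lifetime of the (Python-visible) object.
// (Real code would probably add dtype/contiguity checks in the constructor.)
class OwningIndex {
public:
  OwningIndex(torch::Tensor start, torch::Tensor size)
      : start_(std::move(start)), size_(std::move(size)) {}

  int64_t n() const { return start_.size(0); }
  int64_t const *start_ptr() const { return start_.data_ptr<int64_t>(); }
  int64_t const *size_ptr() const { return size_.data_ptr<int64_t>(); }

private:
  torch::Tensor start_;
  torch::Tensor size_;
};

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Only construction and a size query are exposed to Python; kernels would
  // use the raw pointer getters on the C++ side.
  py::class_<OwningIndex>(m, "Index")
      .def(py::init<torch::Tensor, torch::Tensor>())
      .def("n", &OwningIndex::n);
}

Since the tensors are held by value, no manual lifetime management is needed, but most of the Index logic then has to live in C++, which is the trade-off mentioned in 2.2.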
Below is the test code I wrote for the third option. I used only a small Index, but it should be clear how to extend it to more fields. I would be happy to hear any suggestions, and please also tell me if this is already solved/documented somewhere and I missed it.
Thank you
Lukas
C++ code
#include <torch/extension.h>

#include <fstream>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Just a POD view: the raw pointers are owned by torch::Tensors elsewhere.
class Index {
public:
  int64_t n = 0;
  int64_t const *start = nullptr;
  int64_t const *size = nullptr;

  void print() const {
    std::cout << "Index: n=" << n << ", start=" << start << ", size=" << size << std::endl;
  }
};

auto dict_from_file(std::string const &fname) {
  std::ifstream input(fname, std::ios::binary);
  std::vector<char> bytes{std::istreambuf_iterator<char>(input), std::istreambuf_iterator<char>()};
  input.close();
  return torch::pickle_load(bytes).toGenericDict();
}

Index index_from_dict(std::unordered_map<std::string, torch::Tensor> const &dict) {
  Index index;
  index.n = dict.at("n").item().toLong();
  index.start = dict.at("start").data_ptr<int64_t>();
  index.size = dict.at("size").data_ptr<int64_t>();
  return index;
}

Index index_from_dict(c10::Dict<c10::IValue, c10::IValue> const &dict_in) {
  std::unordered_map<std::string, torch::Tensor> dict;
  for (auto const &elem : dict_in) dict[elem.key().toStringRef()] = elem.value().toTensor();
  return index_from_dict(dict);
}

void test_from_python(std::unordered_map<std::string, torch::Tensor> dict) {
  auto index = index_from_dict(dict);
  index.print();
}

void test_from_file(std::string fname) {
  auto dict = dict_from_file(fname); // Owns the tensors, so the pointers in index stay valid here.
  auto index = index_from_dict(dict);
  index.print();
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("test_from_python", &test_from_python, "Test passing index from python");
  m.def("test_from_file", &test_from_file, "Test loading index from file");
}
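For completeness, one way to produce a file that test_from_file can read is a counterpart to dict_from_file using torch::pickle_save, appended to the same extension file above (so the includes are already available). The helper name save_index_dict is made up for this test, and the sketch assumes the same "n"/"start"/"size" keys as index_from_dict.

// Counterpart to dict_from_file: pickles a dict with the same keys to disk.
// Could be exposed via m.def("save_index_dict", &save_index_dict) for testing.
void save_index_dict(std::string const &fname, torch::Tensor start, torch::Tensor size) {
  c10::Dict<std::string, torch::Tensor> dict;
  dict.insert("n", torch::tensor(start.size(0)));
  dict.insert("start", start);
  dict.insert("size", size);
  std::vector<char> bytes = torch::pickle_save(dict);
  std::ofstream output(fname, std::ios::binary);
  output.write(bytes.data(), bytes.size());
}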