Advice on clean C++ Interface involving a class with many members

Dear PyTorch Team,

Thank you very much for the library, which is very pleasant to use and has great documentation!

I have already written multiple CUDA extensions for PyTorch, and everything works well as long as the C++ function takes torch::Tensors or a std::vector of them as input. However, I now have an Index class with about 20 members, each of which is a pointer to some (GPU) memory owned by a torch::Tensor. My question is how best to pass this to C++. I have listed the options I see below; I think the last one is best, but I would be happy to hear ideas and opinions.

  1. Just pass all the torch::Tensors as arguments or as a std::vector<torch::Tensor>.
    1.1 Pro: Easy implementation, easy to understand.
    1.2 Con: Not flexible. If the indexing changes slightly, the change has to be made manually in every function.

  2. Implement the Index class in C++ and expose all necessary functionality through pybind11.
    2.1 Pro: Full control in C++. Encapsulated.
    2.2 Con: No control in Python and thus no “hackability”. I don’t know whether memory management, reference counting, etc. work out of the box in this case.

  3. Currently preferred. Implement Index in Python (and in C++, but there with pointers only) and construct a new C++ Index each time a function/kernel is called.
    3.1 Pro: Easy implementation, easy to understand. Most things (e.g. generating the Index from data) can be implemented in Python.
    3.2 Con: Maybe a (very) slight runtime overhead? I have to make sure manually that the pointers in C++ stay valid. Possibly other issues I forgot.
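
One way to reduce the pointer-validity risk of option 3 might be to keep ref-counted owners next to the raw pointers on the C++ side and hand kernels only a cheap pointer view. The following is a minimal sketch of that idea, not libtorch code: `Buffer` (a `std::shared_ptr<std::vector<int64_t>>`) is a hypothetical stand-in for `torch::Tensor`, and `IndexOwner`, `IndexView`, and `total_size` are made-up names for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <numeric>
#include <vector>

// Hypothetical stand-in for torch::Tensor: a ref-counted handle to int64 data.
using Buffer = std::shared_ptr<std::vector<int64_t>>;

// Plain pointer view, cheap to copy and to pass to a kernel by value.
struct IndexView {
  int64_t n = 0;
  int64_t const *start = nullptr;
  int64_t const *size = nullptr;
};

// Holds the owning handles, so the memory stays alive for every view
// created from it as long as the IndexOwner itself is alive.
struct IndexOwner {
  Buffer start;
  Buffer size;

  IndexView view() const {
    // Both buffers must exist and describe the same number of entries.
    assert(start && size && start->size() == size->size());
    IndexView v;
    v.n = static_cast<int64_t>(start->size());
    v.start = start->data();
    v.size = size->data();
    return v;
  }
};

// Example consumer that only sees raw pointers, as a kernel launch would.
inline int64_t total_size(IndexView v) {
  return std::accumulate(v.size, v.size + v.n, int64_t{0});
}
```

With real tensors, the owner would presumably hold `torch::Tensor` members and `view()` would call `data_ptr<int64_t>()` on them. The view must still not outlive the owner, but that rule then lives in one place instead of in every function.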

Below is the test code I wrote for the third option. I used only a small Index, but extending it to more fields should be straightforward. I would be happy to hear any suggestions, and please also tell me if this is already solved or documented somewhere and I missed it.

Thank you

Cpp Code

#include <torch/extension.h>

#include <fstream>
#include <iostream>
#include <iterator>
#include <string>
#include <unordered_map>
#include <vector>

// Just POD: raw pointers into memory that is owned by torch::Tensors elsewhere
class Index {
public:
  int64_t n = 0;
  int64_t const *start = nullptr;
  int64_t const *size = nullptr;

  void print() const { std::cout << "Index: n=" << n << ", start=" << start << ", size=" << size << std::endl; }
};

// Read the whole file into memory and unpickle it as a generic dict
auto dict_from_file(std::string const &fname) {
  std::ifstream input(fname, std::ios::binary);
  std::vector<char> bytes{std::istreambuf_iterator<char>(input), std::istreambuf_iterator<char>()};

  return torch::pickle_load(bytes).toGenericDict();
}

// The returned pointers are only valid while the tensors in dict are alive
Index index_from_dict(std::unordered_map<std::string, torch::Tensor> const &dict) {
  Index index;
  index.n = dict.at("n").item().toLong();
  index.start = dict.at("start").data_ptr<int64_t>();
  index.size = dict.at("size").data_ptr<int64_t>();
  return index;
}

// Convert the generic dict to a typed map; the tensors remain owned by dict_in
Index index_from_dict(c10::Dict<c10::IValue, c10::IValue> const &dict_in) {
  std::unordered_map<std::string, torch::Tensor> dict;
  for (auto const &elem : dict_in) dict[elem.key().toStringRef()] = elem.value().toTensor();
  return index_from_dict(dict);
}

void test_from_python(std::unordered_map<std::string, torch::Tensor> dict) {
  auto index = index_from_dict(dict);
  index.print();
}

void test_from_file(std::string fname) {
  auto dict = dict_from_file(fname); // Owns the dict
  auto index = index_from_dict(dict);
  index.print();
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("test_from_python", &test_from_python, "Test passing index from python");
  m.def("test_from_file", &test_from_file, "Test loading index from file");
}
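
Since the Index is built from a dict keyed by field name, a forgotten field only shows up at runtime, and `std::unordered_map::at` throws a `std::out_of_range` whose message usually does not name the key. A small lookup helper could make such failures easier to diagnose. The following is only a sketch: `require_key` is a hypothetical name, and it is written against a generic `std::unordered_map` so it does not depend on libtorch.

```cpp
#include <stdexcept>
#include <string>
#include <unordered_map>

// Look up a required entry and fail with a message that names the missing key.
// Map is assumed to be a std::unordered_map (or similar) keyed by std::string.
template <class Map>
typename Map::mapped_type const &require_key(Map const &dict, std::string const &key) {
  auto it = dict.find(key);
  if (it == dict.end()) {
    throw std::invalid_argument("Index dict is missing required key '" + key + "'");
  }
  return it->second;
}
```

In the code above, `dict.at("start")` could then be written as `require_key(dict, "start")`, so that a missing field produces an error mentioning `start`.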

Python Code and Output