I am trying to pass a `TensorList` to a C++ extension, as is done, for example, in `torch.nn.utils.rnn.pad_sequence`. I am new to extending PyTorch and have been following this tutorial.
A very basic function to test whether I can pass a list of tensors could be:
```cpp
#include <torch/extension.h>
#include <pybind11/pybind11.h>

// Minimal test: accepts a list of tensors and always returns true.
bool test_tensorlist(torch::TensorList x) {
  return true;
}

PYBIND11_MODULE(tensorlist, m) {
  m.doc() = "Test how to pass tensor list to C++ extension.";
  m.def("test_tensorlist", &test_tensorlist, "Pass list of tensors (as torch::TensorList).");
}
```
Unfortunately, calling this from Python results in:
```
TypeError: incompatible function arguments. The following argument types are supported:
    1. (arg0: c10::ArrayRef<at::Tensor>) -> bool

Invoked with: [tensor([1]), tensor([2])]
```
I get exactly the same result if I use `at::TensorList` or `c10::ArrayRef<at::Tensor>`. Passing a single tensor via `torch::Tensor` or a list of integers via `c10::ArrayRef<int64_t>` works like a charm, though.
Any ideas what I might be doing wrong?