Passing STL containers to torch::Tensor

I am wondering what the best way is to convert a std::vector (and other containers) to a torch::Tensor. Built-in arrays work like this:

 float data[] = {1, 2, 3, 4, 5, 6, 7};
 torch::Tensor f = torch::from_blob(data, {2,3});

and apparently there is no bounds checking, but that's a different topic.
I pass a std::vector like this:

 std::vector<float> v{1, 2, 3, 4, 5, 6};
 torch::Tensor f = torch::from_blob(std::data(v), {2, 3});

  1. Is this the most efficient way to do that? There are no overloads for std::vector, so I have to comply with the signature that takes the infamous void* pointer.

  2. What's the deal with the at::ArrayRef syntax? When should one use it, and is creating a view like this

 auto img_ = cv::imread(argv[1], cv::IMREAD_COLOR);
 cv::Mat img(480, 640, CV_8UC3);
 cv::resize(img_, img, img.size(), 0, 0, cv::INTER_AREA);
 auto input_ = torch::tensor(at::ArrayRef<uint8_t>(img.data, img.rows * img.cols * 3)).view({img.rows, img.cols, 3});

generally more efficient than the approach above?

  3. What's the difference between

 auto t = torch::from_blob(std::data(v), {2, 3});

and

 auto t = torch::CUDA(torch::kFloat32).tensorFromBlob(std::data(v), {2, 3});

I thought torch::Tensor instances are moved to the GPU by passing the appropriate tensor options to the respective factory function…

  4. What does toTensor() return, for example in this line?

 auto res = module->forward(inputs).toTensor();

I thought the result of the forward pass is already of type torch::Tensor?

Unfortunately, the C++ frontend docs are not very clear on these matters yet.

  1. v.data() (or std::data(v)) returns a float* to the vector's contiguous storage, which can be passed to torch::from_blob(). from_blob() wraps that buffer without copying it, so the tensor is only valid while the vector is alive and not reallocated.
  2. at::ArrayRef shouldn't make things faster (torch::tensor() copies the data, whereas from_blob() does not), and auto input_ = torch::from_blob(img.data, {img.rows, img.cols, 3}, torch::kUInt8) should just work. The kUInt8 dtype is needed because from_blob() otherwise interprets the bytes as float.
  3. torch::from_blob() calls tensorFromBlob() internally (https://github.com/pytorch/pytorch/blob/075c7b1fef537205a0bdb8e45d2f800e6c024603/aten/src/ATen/templates/NativeFunctions.h#L28-L34). The only difference between those two statements is where the device and data type are specified (for torch::from_blob() we can pass a TensorOptions parameter to specify the device and data type).
  4. If module is a JIT (TorchScript) model, the result of module->forward(inputs) is a torch::jit::IValue, and toTensor() retrieves the torch::Tensor from it. If module is a C++ frontend model, the result of module->forward(inputs) should already be a torch::Tensor, and we don't need toTensor().

very clear and precise, thank you very much!

What if I have a std::vector<torch::Tensor>? What's the best way to convert it to a single tensor?
Meanwhile, I'm using torch::stack({myvec}) to convert everything into one tensor. torch::from_blob doesn't seem to work with vectors of tensors: it doesn't raise an error, but it produces completely weird output. Am I witnessing undefined behavior here, because it expects a float* and I'm passing a Tensor*?
Thanks a lot in advance

@Shisho_Sama
from_blob currently doesn't support a list or vector of tensors; torch::stack is the right function to use at this point.
